 #include <aclnnop/aclnn_eq_tensor.h>
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
+#include <aclnnop/aclnn_grouped_matmul_v2.h>
 #include <float.h>
 
 #include <cmath>
@@ -83,6 +84,12 @@ struct mmid_row_mapping {
     int32_t i2;
 };
 
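+// Bookkeeping for one expert's slice of the packed src1/dst buffers: the
+// (token, expert-slot) pairs it serves, how many rows were gathered for it,
+// and the row offset of its block inside the contiguous buffers.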
+struct expert_mapping {
+    std::vector<mmid_row_mapping> row_mappings;
+    int64_t num_src1_rows;
+    int64_t offset;
+};
+
 void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0,
                  aclTensor ** acl_src1, aclTensor ** acl_dst) {
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0));
@@ -2593,14 +2600,15 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
 }
 
-void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // dst   [M, K, N, 1]
     ggml_tensor * src0 = dst->src[0];  // src0 [D, M, A, 1]
     ggml_tensor * src1 = dst->src[1];  // src1 [D, B, N, 1], B = K or B = 1
     ggml_tensor * ids  = dst->src[2];  // ids  [K, N]
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    // copy the expert indices (ids) from NPU memory to the host (ids_host)
     int64_t n_as = ne02; // A
     int64_t n_ids = ids->ne[0]; // K
 
@@ -2613,36 +2621,49 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original  = (char *) dst->data;
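+    // keep src0's original strides around; they are overwritten below when
+    // src0 is cast from F16 to F32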
+    size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03};
+
+    // src0 may be F16 while src1/dst are F32: cast src0 to F32 so that every
+    // operand handed to GroupedMatmulV2 shares the same dtype
+    ggml_cann_pool_alloc src0_cast_allocator;
+    if (src0->type == GGML_TYPE_F16) {
+        src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
+        void* src0_cast_buf = src0_cast_allocator.get();
+
+        size_t cast_nb[GGML_MAX_DIMS];
+        cast_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
+        }
+
+        aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0);
+        aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf,
+            ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
+        ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
 
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
-
-    // src0_row [D, M, 1, 1]
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[3] = nb02;
-
-    // src1_row [D, 1, 1, 1]
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
-
-    // dst_row [D, 1, 1, 1]
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
+        src0_original = (char *) src0_cast_buf;
+        memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb));
+    }
 
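+    // collect one 2D view per expert that has work to do; all views are then
+    // submitted in a single GroupedMatmulV2 launch instead of one mat-mul
+    // call per expert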
+    std::vector<aclTensor*> src0_tensor_vec;
+    std::vector<aclTensor*> src1_tensor_vec;
+    std::vector<aclTensor*> dst_tensor_vec;
     // ne12 == ids->ne[1] == N
     if (ne12 == 1) {
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
+                // src0_row [M, D] -> weight, viewed transposed by swapping ne/nb
+                int64_t src0_ne[2] = {ne01, ne00};
+                size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
+                // src1_row [D, 1] -> input
+                int64_t src1_ne[2] = {ne10, 1};
+                size_t src1_nb[2] = {nb10, nb11};
+                // dst_row [M, 1] -> out
+                int64_t dst_ne[2] = {ne0, 1};
+                size_t dst_nb[2] = {nb0, nb1};
+
+                // expert index for this (token, slot) pair
                 int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-
                 GGML_ASSERT(i02 >= 0 && i02 < n_as);
 
                 // If B = 1 (broadcast), always use 0; otherwise, use id.
@@ -2652,30 +2673,51 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 int64_t i1 = id;
                 int64_t i2 = i12;
 
-                src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
-                dst_row.data = dst_original + i1*nb1 + i2*nb2;
-                dst_row.src[0] = &src0_row;
-                dst_row.src[1] = &src1_row;
-                ggml_cann_mul_mat(ctx, &dst_row);
+                void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
+                void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
+                void* dst_tmp_ptr  = dst_original  + i1*nb1   + i2*nb2;
+
+                aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
+                                          ACL_FLOAT, sizeof(float),
+                                          src0_ne, src0_nb, 2);
+                aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
+                                          ACL_FLOAT, sizeof(float),
+                                          src1_ne, src1_nb, 2);
+                aclTensor* acl_dst  = ggml_cann_create_tensor(dst_tmp_ptr,
+                                          ACL_FLOAT, sizeof(float),
+                                          dst_ne, dst_nb, 2);
+
+                src0_tensor_vec.push_back(acl_src0);
+                src1_tensor_vec.push_back(acl_src1);
+                dst_tensor_vec.push_back(acl_dst);
             }
         }
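+        // NOTE: aclnnGroupedMatmulV2 takes the inputs (x) first and the weights
+        // second; the six nullptrs are the optional bias/scale/offset/antiquant/
+        // groupList arguments, splitItem = 0 keeps per-group outputs separate,
+        // and groupType = -1 means the operands are already split into per-group
+        // tensors (a reading of the CANN aclnn API docs, not verified here)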
+        aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec.data(), src0_tensor_vec.size());
+        aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec.data(), src1_tensor_vec.size());
+        aclTensorList* dst_tensor_list  = aclCreateTensorList(dst_tensor_vec.data(), dst_tensor_vec.size());
+
+        GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
+            nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
+
+        ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
     } else {
         ggml_cann_pool_alloc src1_cont_allocator(
-            ctx.pool(), sizeof(float) * ggml_nelements(src1));
+            ctx.pool(), sizeof(float) * ggml_nelements(src1) / ne11 * ids->ne[0]);
         ggml_cann_pool_alloc dst_cont_allocator(
             ctx.pool(), sizeof(float) * ggml_nelements(dst));
 
         void* src1_cont_buf = src1_cont_allocator.get();
         void* dst_cont_buf = dst_cont_allocator.get();
 
-        src1_row.data = src1_cont_buf;
-        dst_row.data = dst_cont_buf;
-
+        std::vector<expert_mapping> expert_mappings;
+        int64_t total_num_src1_rows = 0;
         for (int64_t i02 = 0; i02 < n_as; i02++) {
             std::vector<mmid_row_mapping> row_mappings;
             int64_t num_src1_rows = 0;
 
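+            // base pointers for this expert: its weight slice inside src0 and
+            // the start of its contiguous block in the packed src1/dst buffers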
+            void* src0_tmp_ptr = (char *)src0_original + i02*ori_src0_nb[2];
+            void* src1_tmp_ptr = (char *)src1_cont_buf + total_num_src1_rows * nb11;
+            void* dst_tmp_ptr  = (char *)dst_cont_buf  + total_num_src1_rows * nb1;
             for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
                 for (int64_t id = 0; id < n_ids; id++) {
                     int32_t row_id_i = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
@@ -2686,54 +2728,91 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                         mapping.i1 = static_cast<int32_t>(id);
                         mapping.i2 = static_cast<int32_t>(iid1);
                         row_mappings.push_back(mapping);
-                        num_src1_rows++;
 
                         int64_t read_b = (ne11 == 1 ? 0 : id);
                         char * src_ptr = src1_original
                                        + read_b * nb11
                                        + mapping.i2 * nb12;
-                        char * dst_ptr = (char *)src1_cont_buf + (num_src1_rows - 1) * nb11;
-                        ACL_CHECK(aclrtMemcpyAsync(dst_ptr, nb11, src_ptr, nb11,
-                                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                        char * dst_ptr = (char *)src1_cont_buf + total_num_src1_rows * nb11;
+                        ggml_cann_async_memcpy(ctx, dst_ptr, src_ptr, nb11,
+                                               ACL_MEMCPY_DEVICE_TO_DEVICE);
+
+                        num_src1_rows++;
+                        total_num_src1_rows++;
                     }
                 }
             }
+
+            // expert_mappings is indexed by the expert id i02
+            expert_mapping expert_map;
+            expert_map.row_mappings = std::move(row_mappings);
+            expert_map.num_src1_rows = num_src1_rows;
+            expert_map.offset = total_num_src1_rows - num_src1_rows;
+            expert_mappings.push_back(std::move(expert_map));
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            // src0_row [D, M, 1, 1]
-            src0_row.data = src0_original + i02 * nb02;
-
-            // src1_row [D, The number of values in K * N is i02, 1, 1]
-            src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows * nb11;
-            src1_row.nb[3] = num_src1_rows * nb11;
-
-            // dst_row [D, The number of values in K * N is i02, 1, 1]
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows * nb1;
-            dst_row.nb[3] = num_src1_rows * nb1;
-
-            dst_row.src[0] = &src0_row;
-            dst_row.src[1] = &src1_row;
+            // src0_row [M, D] -> weight, viewed transposed by swapping ne/nb
+            int64_t src0_ne[2] = {ne01, ne00};
+            size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
+            // src1_row [D, num_src1_rows] -> input
+            int64_t src1_ne[2] = {ne10, num_src1_rows};
+            size_t src1_nb[2] = {nb10, nb11};
+            // dst_row [M, num_src1_rows] -> out
+            int64_t dst_ne[2] = {ne0, num_src1_rows};
+            size_t dst_nb[2] = {nb0, nb1};
+
+            aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
+                                      ACL_FLOAT, sizeof(float),
+                                      src0_ne, src0_nb, 2);
+            aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
+                                      ACL_FLOAT, sizeof(float),
+                                      src1_ne, src1_nb, 2);
+            aclTensor* acl_dst  = ggml_cann_create_tensor(dst_tmp_ptr,
+                                      ACL_FLOAT, sizeof(float),
+                                      dst_ne, dst_nb, 2);
+
+            src0_tensor_vec.push_back(acl_src0);
+            src1_tensor_vec.push_back(acl_src1);
+            dst_tensor_vec.push_back(acl_dst);
+        }
+        aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec.data(), src0_tensor_vec.size());
+        aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec.data(), src1_tensor_vec.size());
+        aclTensorList* dst_tensor_list  = aclCreateTensorList(dst_tensor_vec.data(), dst_tensor_vec.size());
 
-            ggml_cann_mul_mat(ctx, &dst_row);
+        GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list,
+            nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
 
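+        // scatter the packed per-expert results back to their original
+        // (token, expert-slot) positions in dst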
-            for (int64_t i = 0; i < num_src1_rows; ++i) {
-                int64_t i1 = row_mappings[i].i1;
-                int64_t i2 = row_mappings[i].i2;
+        for (size_t i = 0; i < expert_mappings.size(); ++i) {
+            const expert_mapping & expert_map = expert_mappings[i];
+            for (int64_t j = 0; j < expert_map.num_src1_rows; ++j) {
+                int64_t i1 = expert_map.row_mappings[j].i1;
+                int64_t i2 = expert_map.row_mappings[j].i2;
 
-                char * src_ptr = (char *)dst_cont_buf + i * nb1;
+                char * src_ptr = (char *)dst_cont_buf + (expert_map.offset + j) * nb1;
                 char * dst_ptr = dst_original + i1 * nb1 + i2 * nb2;
 
-                ACL_CHECK(aclrtMemcpyAsync(dst_ptr, nb1, src_ptr, nb1,
-                                           ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                ggml_cann_async_memcpy(ctx, dst_ptr, src_ptr, nb1,
+                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
             }
-        }
+        }
+        ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
     }
     return;
 }
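+
+// dispatch on the weight (src0) dtype: F32 and F16 take the floating-point
+// path above; quantized weight types are not handled by this kernel yet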
+void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    const enum ggml_type type = dst->src[0]->type;
+    switch (type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+            ggml_cann_mul_mat_id_fp(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("Unsupported type for mul_mat_id");
+            break;
+    }
+}