6565#include  < aclnnop/aclnn_eq_tensor.h> 
6666#include  < aclnnop/aclnn_gt_scalar.h> 
6767#include  < aclnnop/aclnn_pow.h> 
68- #include  < aclnnop/aclnn_grouped_matmul_v2 .h> 
68+ #include  < aclnnop/aclnn_grouped_matmul_v3 .h> 
6969#include  < aclnnop/aclnn_fused_infer_attention_score_v2.h> 
7070#include  < float.h> 
7171
@@ -2654,6 +2654,67 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
26542654        memcpy (ori_src0_nb, cast_nb, sizeof (ori_src0_nb));
26552655    }
26562656
2657+ #ifdef  ASCEND_310P
2658+     ggml_tensor src0_row = *src0;
2659+     ggml_tensor src1_row = *src1;
2660+     ggml_tensor dst_row = *dst;
2661+ 
2662+     if  (src0->type  == GGML_TYPE_F16) {
2663+         src0_row.type  = GGML_TYPE_F32;
2664+     }
2665+ 
2666+     //  src0_row [D, M, 1, 1] weight without permute
2667+     src0_row.ne [2 ] = 1 ;
2668+     src0_row.ne [3 ] = 1 ;
2669+     src0_row.nb [0 ] = ori_src0_nb[0 ];
2670+     src0_row.nb [1 ] = ori_src0_nb[1 ];
2671+     src0_row.nb [2 ] = ori_src0_nb[1 ];
2672+     src0_row.nb [3 ] = ori_src0_nb[1 ];
2673+ 
2674+     //  src1_row [D, 1, 1, 1] -> input
2675+     src1_row.ne [1 ] = 1 ;
2676+     src1_row.ne [2 ] = 1 ;
2677+     src1_row.ne [3 ] = 1 ;
2678+     src1_row.nb [2 ] = nb11;
2679+     src1_row.nb [3 ] = nb11;
2680+ 
2681+     //  dst_row [M, 1, 1, 1] -> out
2682+     dst_row.ne [1 ] = 1 ;
2683+     dst_row.ne [2 ] = 1 ;
2684+     dst_row.ne [3 ] = 1 ;
2685+     dst_row.nb [2 ] = nb1;
2686+     dst_row.nb [3 ] = nb1;
2687+ 
2688+     // create weight for one row
2689+     for  (int64_t  iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
2690+         for  (int64_t  id = 0 ; id < n_ids; id++) {
2691+             //  expert index
2692+             int32_t  i02 = *(int32_t  *) (ids_host.data () + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
2693+             GGML_ASSERT (i02 >= 0  && i02 < n_as);
2694+ 
2695+             //  If B = 1 (broadcast), always use 0; otherwise, use id.
2696+             int64_t  i11 = (ne11 == 1  ? 0  : id);
2697+             int64_t  i12 = iid1;
2698+ 
2699+             int64_t  i1 = id;
2700+             int64_t  i2 = i12;
2701+ 
2702+             void * src0_tmp_ptr = src0_original + i02*ori_src0_nb[2 ];
2703+             void * src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2704+             void * dst_tmp_ptr  = dst_original  + i1*nb1   + i2*nb2;
2705+ 
2706+             src0_row.data  = src0_tmp_ptr;
2707+             src1_row.data  = src1_tmp_ptr;
2708+             dst_row.data  = dst_tmp_ptr;
2709+             dst_row.src [0 ] = &src0_row;
2710+             dst_row.src [1 ] = &src1_row;
2711+ 
2712+             ggml_cann_mul_mat (ctx, &dst_row);
2713+         }
2714+     }
2715+     return ;
2716+ #endif 
2717+ 
26572718    std::vector<aclTensor*> src0_tensor_vec;
26582719    std::vector<aclTensor*> src1_tensor_vec;
26592720    std::vector<aclTensor*> dst_tensor_vec;
@@ -2701,9 +2762,9 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27012762    }
27022763
27032764    size_t  GROUP_SIZE = 128 ;
2704-     //  GroupedMatmulV2  required tensor_list.size < 128
2765+     //  GroupedMatmulV3 limits tensor_list.size, so pass at most GROUP_SIZE (128) tensors per call
27052766    for  (size_t  i = 0 ; i < src0_tensor_vec.size (); i += GROUP_SIZE) {
2706-         //  split and call GroupedMatmulV2 
2767+         //  split and call GroupedMatmulV3 
27072768        size_t  end = std::min (i + GROUP_SIZE, src0_tensor_vec.size ());
27082769        std::vector<aclTensor*> src0_tensor_vec_split (src0_tensor_vec.begin () + i, src0_tensor_vec.begin () + end);
27092770        std::vector<aclTensor*> src1_tensor_vec_split (src1_tensor_vec.begin () + i, src1_tensor_vec.begin () + end);
@@ -2713,7 +2774,7 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27132774        aclTensorList* src1_tensor_list = aclCreateTensorList (src1_tensor_vec_split.data (), src1_tensor_vec_split.size ());
27142775        aclTensorList* dst_tensor_list = aclCreateTensorList (dst_tensor_vec_split.data (), dst_tensor_vec_split.size ());
27152776
2716-         GGML_CANN_CALL_ACLNN_OP (ctx, GroupedMatmulV2 , src1_tensor_list, src0_tensor_list,
2777+         GGML_CANN_CALL_ACLNN_OP (ctx, GroupedMatmulV3 , src1_tensor_list, src0_tensor_list,
27172778            nullptr , nullptr , nullptr , nullptr , nullptr , nullptr , 0 , -1 , dst_tensor_list);
27182779
27192780        ggml_cann_release_resources (ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
0 commit comments