@@ -2654,6 +2654,67 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
26542654 memcpy (ori_src0_nb, cast_nb, sizeof (ori_src0_nb));
26552655 }
26562656
2657+ #ifdef ASCEND_310P
2658+ ggml_tensor src0_row = *src0;
2659+ ggml_tensor src1_row = *src1;
2660+ ggml_tensor dst_row = *dst;
2661+
2662+ if (src0->type == GGML_TYPE_F16) {
2663+ src0_row.type = GGML_TYPE_F32;
2664+ }
2665+
2666+ // src0_row [D, M, 1, 1] weight without permute
2667+ src0_row.ne [2 ] = 1 ;
2668+ src0_row.ne [3 ] = 1 ;
2669+ src0_row.nb [0 ] = ori_src0_nb[0 ];
2670+ src0_row.nb [1 ] = ori_src0_nb[1 ];
2671+ src0_row.nb [2 ] = ori_src0_nb[1 ];
2672+ src0_row.nb [3 ] = ori_src0_nb[1 ];
2673+
2674+ // src1_row [D, 1, 1, 1] -> input
2675+ src1_row.ne [1 ] = 1 ;
2676+ src1_row.ne [2 ] = 1 ;
2677+ src1_row.ne [3 ] = 1 ;
2678+ src1_row.nb [2 ] = nb11;
2679+ src1_row.nb [3 ] = nb11;
2680+
2681+ // dst_row [M, 1, 1, 1] -> out
2682+ dst_row.ne [1 ] = 1 ;
2683+ dst_row.ne [2 ] = 1 ;
2684+ dst_row.ne [3 ] = 1 ;
2685+ dst_row.nb [2 ] = nb1;
2686+ dst_row.nb [3 ] = nb1;
2687+
2688+ // create weight for one row
2689+ for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
2690+ for (int64_t id = 0 ; id < n_ids; id++) {
2691+ // expert index
2692+ int32_t i02 = *(int32_t *) (ids_host.data () + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
2693+ GGML_ASSERT (i02 >= 0 && i02 < n_as);
2694+
2695+ // If B = 1 (broadcast), always use 0; otherwise, use id.
2696+ int64_t i11 = (ne11 == 1 ? 0 : id);
2697+ int64_t i12 = iid1;
2698+
2699+ int64_t i1 = id;
2700+ int64_t i2 = i12;
2701+
2702+ void * src0_tmp_ptr = src0_original + i02*ori_src0_nb[2 ];
2703+ void * src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2704+ void * dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2705+
2706+ src0_row.data = src0_tmp_ptr;
2707+ src1_row.data = src1_tmp_ptr;
2708+ dst_row.data = dst_tmp_ptr;
2709+ dst_row.src [0 ] = &src0_row;
2710+ dst_row.src [1 ] = &src1_row;
2711+
2712+ ggml_cann_mul_mat (ctx, &dst_row);
2713+ }
2714+ }
2715+ return ;
2716+ #endif
2717+
26572718 std::vector<aclTensor*> src0_tensor_vec;
26582719 std::vector<aclTensor*> src1_tensor_vec;
26592720 std::vector<aclTensor*> dst_tensor_vec;
0 commit comments