@@ -2700,14 +2700,10 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27002700 }
27012701 }
27022702
2703- // GroupedMatmulV2 required tensor_list.size < 128
27042703 size_t GROUP_SIZE = 128 ;
2705- std::vector<std::vector<aclTensor*>> src0_tensor_vec_vec;
2706- std::vector<std::vector<aclTensor*>> src1_tensor_vec_vec;
2707- std::vector<std::vector<aclTensor*>> dst_tensor_vec_vec;
2708-
2709- // split and call GroupedMatmulV2
2704+ // GroupedMatmulV2 required tensor_list.size < 128
27102705 for (size_t i = 0 ; i < src0_tensor_vec.size (); i += GROUP_SIZE) {
2706+ // split and call GroupedMatmulV2
27112707 size_t end = std::min (i + GROUP_SIZE, src0_tensor_vec.size ());
27122708 std::vector<aclTensor*> src0_tensor_vec_split (src0_tensor_vec.begin () + i, src0_tensor_vec.begin () + end);
27132709 std::vector<aclTensor*> src1_tensor_vec_split (src1_tensor_vec.begin () + i, src1_tensor_vec.begin () + end);
@@ -2725,13 +2721,144 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor*
27252721 return ;
27262722}
27272723
2724+ /* *
2725+ * @brief Performs expert-specific matrix multiplication (MoE) with
2726+ * quantized precision using the CANN backend.
2727+ *
2728+ * This function executes a matrix multiplication operation tailored for
2729+ * Mixture of Experts (MoE) models, where the input tensor is multiplied
2730+ * with expert-specific quantized weight matrices. It leverages the CANN
2731+ * backend to perform efficient low-precision computations and stores the
2732+ * quantized result in the destination tensor `dst`.
2733+ *
2734+ * Quantization techniques reduce memory footprint and improve performance
2735+ * by using lower-bit representations (e.g., int8) instead of floating-point.
2736+ * This function is designed to work with such formats and may incorporate
2737+ * optimizations like identity-based fast paths or routing masks for sparse
2738+ * expert selection.
2739+ *
2740+ * @param ctx The context for executing CANN backend operations.
2741+ * @param dst The destination tensor where the quantized MoE multiplication result
2742+ * will be stored.
2743+ *
2744+ * @note This function assumes quantized data types and is designed for
2745+ * MoE architectures with potential sparse expert routing.
2746+ */
2747+ static void ggml_cann_mul_mat_id_quant (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
2748+ // TODO: Use aclnnGroupedMatMul
2749+ // dst [M, K, N, 1]
2750+ ggml_tensor * src0 = dst->src [0 ]; // src0 [D, M, A, 1]
2751+ ggml_tensor * src1 = dst->src [1 ]; // src1 [D, B, N, 1], B = K or B = 1
2752+ ggml_tensor * ids = dst->src [2 ]; // ids [K, N]
2753+
2754+ GGML_TENSOR_BINARY_OP_LOCALS
2755+
2756+ // copy index from npu to cpu
2757+ int64_t n_as = ne02; // A
2758+ int64_t n_ids = ids->ne [0 ]; // K
2759+
2760+ std::vector<char > ids_host (ggml_nbytes (ids));
2761+ ggml_cann_async_memcpy (ctx, ids_host.data (), ids->data , ggml_nbytes (ids),
2762+ ACL_MEMCPY_DEVICE_TO_HOST);
2763+ ACL_CHECK (aclrtSynchronizeStream (ctx.stream ()));
2764+
2765+ char * src0_original = (char *) src0->data ;
2766+ char * src1_original = (char *) src1->data ;
2767+ char * dst_original = (char *) dst->data ;
2768+
2769+ ggml_tensor src0_row = *src0;
2770+ ggml_tensor src1_row = *src1;
2771+ ggml_tensor dst_row = *dst;
2772+
2773+ const enum ggml_type type = dst->src [0 ]->type ;
2774+ float weight_elem_size;
2775+ if (type == GGML_TYPE_Q4_0) {
2776+ weight_elem_size = float (sizeof (uint8_t )) / 2 ;
2777+ } else if (type == GGML_TYPE_Q8_0) {
2778+ weight_elem_size = float (sizeof (uint8_t ));
2779+ } else {
2780+ GGML_ABORT (" MUL_MAT_ID only support quant type Q4_0 and Q8_0 " );
2781+ }
2782+
2783+ // src0_row [D, M, 1, 1] weight without permute
2784+ src0_row.ne [2 ] = 1 ;
2785+ src0_row.ne [3 ] = 1 ;
2786+ src0_row.nb [0 ] = weight_elem_size;
2787+ src0_row.nb [1 ] = weight_elem_size * ne00;
2788+ src0_row.nb [2 ] = weight_elem_size * ne00;
2789+ src0_row.nb [3 ] = weight_elem_size * ne00;
2790+ size_t weight_stride = ne00 * ne01 * weight_elem_size;
2791+ size_t weight_size = weight_stride * ne02 * ne03;
2792+
2793+ // scale [D, M, 1, 1] -> scale && permute
2794+ size_t scale_elem_size = sizeof (uint16_t );
2795+ size_t scale_stride = src0->ne [1 ] * src0->ne [0 ] / QK8_0 * scale_elem_size;
2796+
2797+ // src1_row [D, 1, 1, 1] -> input
2798+ src1_row.ne [1 ] = 1 ;
2799+ src1_row.ne [2 ] = 1 ;
2800+ src1_row.ne [3 ] = 1 ;
2801+ src1_row.nb [2 ] = nb11;
2802+ src1_row.nb [3 ] = nb11;
2803+
2804+ // dst_row [M, 1, 1, 1] -> out
2805+ dst_row.ne [1 ] = 1 ;
2806+ dst_row.ne [2 ] = 1 ;
2807+ dst_row.ne [3 ] = 1 ;
2808+ dst_row.nb [2 ] = nb1;
2809+ dst_row.nb [3 ] = nb1;
2810+
2811+ // create weight for one row
2812+ ggml_cann_pool_alloc weight_allocator (ctx.pool ());
2813+ void * weight_buffer = weight_allocator.alloc (nb02);
2814+ for (int64_t iid1 = 0 ; iid1 < ids->ne [1 ]; iid1++) {
2815+ for (int64_t id = 0 ; id < n_ids; id++) {
2816+ // expert index
2817+ int32_t i02 = *(int32_t *) (ids_host.data () + iid1*ids->nb [1 ] + id*ids->nb [0 ]);
2818+ GGML_ASSERT (i02 >= 0 && i02 < n_as);
2819+
2820+ // If B = 1 (broadcast), always use 0; otherwise, use id.
2821+ int64_t i11 = (ne11 == 1 ? 0 : id);
2822+ int64_t i12 = iid1;
2823+
2824+ int64_t i1 = id;
2825+ int64_t i2 = i12;
2826+
2827+ void * src0_tmp_ptr = src0_original + i02*weight_stride;
2828+ void * scale_tmp_ptr = src0_original + weight_size + i02*scale_stride;
2829+ void * src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
2830+ void * dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
2831+
2832+ // mem cpy
2833+ ggml_cann_async_memcpy (ctx, weight_buffer, src0_tmp_ptr, weight_stride,
2834+ ACL_MEMCPY_DEVICE_TO_DEVICE);
2835+ void * scale_buffer = (char *)weight_buffer + weight_stride;
2836+ ggml_cann_async_memcpy (ctx, scale_buffer, scale_tmp_ptr, scale_stride,
2837+ ACL_MEMCPY_DEVICE_TO_DEVICE);
2838+
2839+ src0_row.data = weight_buffer;
2840+ src1_row.data = src1_tmp_ptr;
2841+ dst_row.data = dst_tmp_ptr;
2842+ dst_row.src [0 ] = &src0_row;
2843+ dst_row.src [1 ] = &src1_row;
2844+
2845+ ggml_cann_mul_mat (ctx, &dst_row);
2846+ }
2847+ }
2848+ return ;
2849+ }
2850+
27282851void ggml_cann_mul_mat_id (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
27292852 const enum ggml_type type = dst->src [0 ]->type ;
27302853 switch (type) {
27312854 case GGML_TYPE_F32:
27322855 case GGML_TYPE_F16:
27332856 ggml_cann_mul_mat_id_fp (ctx, dst);
27342857 break ;
2858+ case GGML_TYPE_Q4_0:
2859+ case GGML_TYPE_Q8_0:
2860+ ggml_cann_mul_mat_id_quant (ctx, dst);
2861+ break ;
27352862 default :
27362863 GGML_ABORT (" Unsupported type for mul_mat_id" );
27372864 break ;
0 commit comments