3333#include  < aclnnop/aclnn_group_norm.h> 
3434#include  < aclnnop/aclnn_index_fill_tensor.h> 
3535#include  < aclnnop/aclnn_layer_norm.h> 
36+ #include  < aclnnop/aclnn_mm.h> 
37+ #include  < aclnnop/aclnn_batch_matmul.h> 
3638#include  < aclnnop/aclnn_matmul.h> 
3739#include  < aclnnop/aclnn_max_pool.h> 
3840#include  < aclnnop/aclnn_permute.h> 
@@ -2423,7 +2425,6 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
24232425                          aclTensor* acl_weight, aclTensor* acl_dst) {
24242426    int8_t  cube_math_type = 1 ;  //  ALLOW_FP32_DOWN_PRECISION, when input is
24252427                                //  fp32, atlas a2 will transpose it to HFLOAT32.
2426- 
24272428    uint64_t  workspaceSize = 0 ;
24282429    aclOpExecutor* executor;
24292430    void * workspaceAddr = nullptr ;
@@ -2441,6 +2442,80 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
24412442        aclnnMatmul (workspaceAddr, workspaceSize, executor, ctx.stream ()));
24422443}
24432444
2445+ /* *
2446+  * @brief Performs matrix multiplication of two 2D tensors. 
2447+  * 
2448+  * This function computes the matrix multiplication of the input tensor 
2449+  * `acl_input` and the weight tensor `acl_weight`, and stores the result in the 
2450+  * destination tensor `acl_dst`. 
2451+  * The operation is defined as: 
2452+  * \f[ 
2453+  *     \text {acl_dst}=\text {acl_input@acl_weight} 
2454+  * \f] 
2455+  * 
2456+  * @param ctx The context for the CANN backend operations. 
2457+  * @param acl_input The input tensor for the matrix multiplication. 
2458+  * @param acl_weight The weight tensor for the matrix multiplication. 
2459+  * @param acl_dst The destination tensor where the result of the matrix 
2460+  * multiplication will be stored. 
2461+  */  
2462+ static  void  aclnn_mat_mul_2d (ggml_backend_cann_context& ctx, aclTensor* acl_input,
2463+                              aclTensor* acl_weight, aclTensor* acl_dst) {
2464+     int8_t  cube_math_type = 2 ;
2465+     uint64_t  workspaceSize = 0 ;
2466+     aclOpExecutor* executor;
2467+     void * workspaceAddr = nullptr ;
2468+ 
2469+     ACL_CHECK (aclnnMmGetWorkspaceSize (acl_input, acl_weight, acl_dst,
2470+                                       cube_math_type, &workspaceSize,
2471+                                       &executor));
2472+ 
2473+     if  (workspaceSize > 0 ) {
2474+         ggml_cann_pool_alloc workspace_allocator (ctx.pool (), workspaceSize);
2475+         workspaceAddr = workspace_allocator.get ();
2476+     }
2477+ 
2478+     ACL_CHECK (
2479+         aclnnMm (workspaceAddr, workspaceSize, executor, ctx.stream ()));
2480+ }
2481+ 
2482+ /* *
2483+  * @brief Performs matrix multiplication of two 3D tensors. 
2484+  * 
2485+  * This function computes the matrix multiplication of the input tensor 
2486+  * `acl_input` and the weight tensor `acl_weight`, and stores the result in the 
2487+  * destination tensor `acl_dst`. 
2488+  * The operation is defined as: 
2489+  * \f[ 
2490+  *     \text {acl_dst}=\text {acl_input@acl_weight} 
2491+  * \f] 
2492+  * 
2493+  * @param ctx The context for the CANN backend operations. 
2494+  * @param acl_input The input tensor for the matrix multiplication. 
2495+  * @param acl_weight The weight tensor for the matrix multiplication. 
2496+  * @param acl_dst The destination tensor where the result of the matrix 
2497+  * multiplication will be stored. 
2498+  */  
2499+ static  void  aclnn_mat_mul_3d (ggml_backend_cann_context& ctx, aclTensor* acl_input,
2500+                              aclTensor* acl_weight, aclTensor* acl_dst) {
2501+     int8_t  cube_math_type = 2 ;
2502+     uint64_t  workspaceSize = 0 ;
2503+     aclOpExecutor* executor;
2504+     void * workspaceAddr = nullptr ;
2505+ 
2506+     ACL_CHECK (aclnnBatchMatMulGetWorkspaceSize (acl_input, acl_weight, acl_dst,
2507+                                                cube_math_type, &workspaceSize,
2508+                                                &executor));
2509+ 
2510+     if  (workspaceSize > 0 ) {
2511+         ggml_cann_pool_alloc workspace_allocator (ctx.pool (), workspaceSize);
2512+         workspaceAddr = workspace_allocator.get ();
2513+     }
2514+ 
2515+     ACL_CHECK (
2516+         aclnnBatchMatMul (workspaceAddr, workspaceSize, executor, ctx.stream ()));
2517+ }
2518+ 
24442519/* *
24452520 * @brief Performs matrix multiplication with floating-point precision on 
24462521 * tensors using the CANN backend. 
@@ -2462,20 +2537,43 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
24622537    //  broadcast, when weight ne2 or ne3 is not 1, weight need repeat.
24632538    BCAST_MUL_MAT_SHAPE (input, weight, dst);
24642539
2465-     //  transpose weight: [1,2,3,4] -> [1,2,4,3]
2466-     int64_t  transpose_ne[] = {bcast_weight_ne[1 ], bcast_weight_ne[0 ],
2467-                               bcast_weight_ne[2 ], bcast_weight_ne[3 ],
2468-                               bcast_weight_ne[4 ], bcast_weight_ne[5 ]};
2469-     size_t  transpose_nb[] = {bcast_weight_nb[1 ], bcast_weight_nb[0 ],
2470-                              bcast_weight_nb[2 ], bcast_weight_nb[3 ],
2471-                              bcast_weight_nb[4 ], bcast_weight_nb[5 ]};
2540+     int64_t  n_dims = bcast_dims;
2541+     if  (bcast_input_ne[3 ] == bcast_weight_ne[3 ] && bcast_input_ne[3 ] == 1 ) {
2542+         if  (bcast_input_ne[2 ] == 1  && bcast_weight_ne[2 ] == 1 ) {
2543+             n_dims = 2 ;
2544+         } else  if  (bcast_input_ne[2 ] == 1 ) {
2545+             n_dims = 3 ;
2546+         }
2547+     }
24722548
2473-     aclTensor* acl_weight_tensor =
2474-         ggml_cann_create_tensor (weight, transpose_ne, transpose_nb, bcast_dims);
24752549    aclTensor* acl_input_tensor =
2476-         ggml_cann_create_tensor (input, BCAST_MUL_MAT_PARAM (input));
2477-     aclTensor* acl_dst = ggml_cann_create_tensor (dst, BCAST_MUL_MAT_PARAM (dst));
2478-     aclnn_mat_mul (ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2550+         ggml_cann_create_tensor (input, bcast_input_ne, bcast_input_nb, n_dims);
2551+     int64_t  transpose_ne[] = {
2552+         bcast_weight_ne[1 ], bcast_weight_ne[0 ],
2553+         bcast_weight_ne[2 ], bcast_weight_ne[3 ],
2554+         bcast_weight_ne[4 ], bcast_weight_ne[5 ]
2555+     };
2556+     size_t  transpose_nb[] = {
2557+         bcast_weight_nb[1 ], bcast_weight_nb[0 ],
2558+         bcast_weight_nb[2 ], bcast_weight_nb[3 ],
2559+         bcast_weight_nb[4 ], bcast_weight_nb[5 ]
2560+     };
2561+     aclTensor* acl_weight_tensor =
2562+         ggml_cann_create_tensor (weight, transpose_ne, transpose_nb, n_dims);
2563+     aclTensor* acl_dst =
2564+         ggml_cann_create_tensor (dst, bcast_dst_ne, bcast_dst_nb, n_dims);
2565+ 
2566+     switch  (n_dims) {
2567+     case  2 :
2568+         aclnn_mat_mul_2d (ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2569+         break ;
2570+     case  3 :
2571+         aclnn_mat_mul_3d (ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2572+         break ;
2573+     default :
2574+         aclnn_mat_mul (ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
2575+         break ;
2576+     }
24792577
24802578    ACL_CHECK (aclDestroyTensor (acl_weight_tensor));
24812579    ACL_CHECK (aclDestroyTensor (acl_input_tensor));
@@ -2501,46 +2599,40 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
25012599    ggml_tensor* src0 = dst->src [0 ];  //  weight
25022600    ggml_tensor* src1 = dst->src [1 ];  //  input
25032601
2504-     //  The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
2505-     //  is regarded as batch. weight need transpose.
2506-     int64_t  weight_ne[] = {src0->ne [1 ], src0->ne [0 ]};
2602+     //  The shape of the weight is NCHW.
2603+     //  Matrix multiplication uses HW dims.
2604+     //  HC is regarded as batch.
2605+     //  weight need transpose.
25072606    float  weight_elem_size;
25082607    if  (type == GGML_TYPE_Q4_0) {
25092608        weight_elem_size = float (sizeof (uint8_t )) / 2 ;
2510-     }
2511-     else  if  (type == GGML_TYPE_Q8_0) {
2609+     } else  if  (type == GGML_TYPE_Q8_0) {
25122610        weight_elem_size = float (sizeof (uint8_t ));
2513-     }
2514-     else  {
2611+     } else  {
25152612        GGML_ABORT (" Only support Q4_0 and Q8_0 MUL_MAT"  );
25162613    }
2517-     float  weight_nb[] = {weight_elem_size * src0->ne [0 ], weight_elem_size};
2518- 
2519-     //  size of one matrix is element_size * height * width.
2520-     size_t  weight_stride = weight_elem_size * src0->ne [0 ] * src0->ne [1 ];
2614+     float  weight_nb[] = {src0->ne [0 ] * weight_elem_size, weight_elem_size};
2615+     size_t  weight_stride = src0->ne [1 ] * src0->ne [0 ] * weight_elem_size;
25212616    size_t  weight_size = weight_stride * src0->ne [2 ] * src0->ne [3 ];
25222617
25232618    //  scale stored at the end of weight. Also need transpose.
2524-     GGML_ASSERT (QK4_0 == QK8_0);
2525-     int64_t  scale_ne[] = {src0->ne [1 ], src0->ne [0 ] / QK8_0};
25262619    size_t  scale_elem_size = sizeof (uint16_t );
2527-     size_t  scale_nb[] = {src0->ne [0 ] / QK8_0 * scale_elem_size,
2528-                          scale_elem_size};
2529-     size_t  scale_stride = scale_elem_size * src0->ne [0 ] * src0->ne [1 ] / QK8_0;
2620+     size_t  scale_nb[] = {src0->ne [0 ] / QK8_0 * scale_elem_size, scale_elem_size};
2621+     size_t  scale_stride = src0->ne [1 ] * src0->ne [0 ] / QK8_0 * scale_elem_size;
25302622    char * scale_offset = (char *)src0->data  + weight_size;
25312623
25322624    //  input
2533-     void * input_buffer;
25342625    size_t  input_elem_size = sizeof (uint16_t );
25352626    int64_t  input_ne[] = {src1->ne [0 ], src1->ne [1 ]};
2536-     size_t  input_nb[] = {input_elem_size, input_elem_size * src1->ne [0 ]};
2537-     size_t  input_stride = input_elem_size * src1->ne [0 ] * src1->ne [1 ];
2538- 
2627+     size_t  input_nb[] = {input_elem_size,  input_ne[0 ] * input_elem_size};
2628+     size_t  input_stride = input_ne[0 ] * input_ne[1 ] * input_elem_size;
25392629    ggml_cann_pool_alloc input_alloctor (ctx.pool ());
2630+     void * input_buffer = src1->data ;
2631+ 
2632+     //  cast in
25402633    if  (src1->type  != GGML_TYPE_F16) {
25412634        aclTensor* acl_src1_tensor = ggml_cann_create_tensor (src1);
2542-         input_alloctor.alloc (ggml_nelements (src1) * input_elem_size);
2543-         input_buffer = input_alloctor.get ();
2635+         input_buffer = input_alloctor.alloc (ggml_nelements (src1) * input_elem_size);
25442636
25452637        int64_t * input_cast_ne = src1->ne ;
25462638        size_t  input_cast_nb[GGML_MAX_DIMS];
@@ -2550,88 +2642,139 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
25502642        }
25512643
25522644        aclTensor* acl_input_tensor = ggml_cann_create_tensor (
2553-             input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
2554-             input_cast_nb, GGML_MAX_DIMS);
2645+             input_buffer,
2646+             ACL_FLOAT16,
2647+             input_elem_size, input_cast_ne, input_cast_nb, GGML_MAX_DIMS);
25552648        aclnn_cast (ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
2649+ 
25562650        ACL_CHECK (aclDestroyTensor (acl_input_tensor));
25572651        ACL_CHECK (aclDestroyTensor (acl_src1_tensor));
2558-     } else  {
2559-         input_buffer = src1->data ;
25602652    }
25612653
25622654    //  output
25632655    size_t  output_elem_size = sizeof (uint16_t );
2564-     int64_t  output_ne[] = {dst->ne [0 ], dst->ne [1 ]};
2565-     size_t  output_nb[] = {output_elem_size, output_elem_size * dst->ne [0 ]};
2566-     ggml_cann_pool_alloc output_alloctor (
2567-         ctx.pool (), ggml_nelements (dst) * output_elem_size);
2568-     void * output_buffer = output_alloctor.get ();
2569-     size_t  output_stride = output_elem_size * dst->ne [0 ] * dst->ne [1 ];
2656+     size_t  output_nb[] = {output_elem_size, dst->ne [0 ] * output_elem_size};
2657+     ggml_cann_pool_alloc output_allocator (ctx.pool ());
2658+     void * output_buffer = output_allocator.alloc (ggml_nelements (dst) * output_elem_size);
2659+     size_t  output_stride = dst->ne [0 ] * dst->ne [1 ] * output_elem_size;
25702660
25712661    //  aclnn
2662+     int64_t  max_elem_size = 65535 ;
2663+     int64_t  split_size = (src0->ne [1 ] / max_elem_size) + 1 ;
2664+     ggml_cann_pool_alloc workspace_allocator (ctx.pool ());
2665+     aclOpExecutor* executor = nullptr ;
25722666    uint64_t  workspaceSize = 0 ;
2573-     aclOpExecutor* executor;
25742667    void * workspaceAddr = nullptr ;
2575- 
25762668    for  (int64_t  n1 = 0 ; n1 < src1->ne [3 ]; n1++) {
25772669        for  (int64_t  c1 = 0 ; c1 < src1->ne [2 ]; c1++) {
25782670            int64_t  n0 = n1 / (src1->ne [3 ] / src0->ne [3 ]);
25792671            int64_t  c0 = c1 / (src1->ne [2 ] / src0->ne [2 ]);
25802672
2581-             int64_t  batch1 = n1 * src1->ne [2 ] + c1;
2582-             int64_t  batch0 = n0 * src0->ne [2 ] + c0;
2673+             int64_t  batch1 = ( n1 * src1->ne [2 ])  + c1;
2674+             int64_t  batch0 = ( n0 * src0->ne [2 ])  + c0;
25832675
25842676            aclTensor* acl_input_tensor = ggml_cann_create_tensor (
25852677                (char *)input_buffer + batch1 * input_stride, ACL_FLOAT16,
25862678                input_elem_size, input_ne, input_nb, 2 );
2679+ 
2680+             //  first split
2681+             int64_t  weight_ne_offset = 0 ;
2682+             int64_t  weight_ne[2 ] = {max_elem_size > src0->ne [1 ] ? src0->ne [1 ] : max_elem_size, src0->ne [0 ]};
2683+             int64_t  scale_ne_offset = 0 ;
2684+             int64_t  scale_ne[2 ] = {weight_ne[0 ], weight_ne[1 ] / QK8_0};
2685+             int64_t  output_ne_offset = 0 ;
2686+             int64_t  output_ne[2 ] = {weight_ne[0 ], dst->ne [1 ]};
2687+ 
25872688            aclTensor* acl_weight_tensor = ggml_cann_create_tensor (
25882689                (char *)src0->data  + batch0 * weight_stride,
2589-                 ggml_cann_type_mapping (type), weight_elem_size, weight_ne,
2590-                 weight_nb, 2 );
2690+                 ggml_cann_type_mapping (type),
2691+                 weight_elem_size, weight_ne, weight_nb, 2 ,
2692+                 ACL_FORMAT_ND, weight_ne_offset);
25912693            aclTensor* acl_scale_tensor = ggml_cann_create_tensor (
2592-                 scale_offset + batch0 * scale_stride, ACL_FLOAT16,
2593-                 scale_elem_size, scale_ne, scale_nb, 2 );
2694+                 scale_offset + batch0 * scale_stride,
2695+                 ACL_FLOAT16,
2696+                 scale_elem_size, scale_ne, scale_nb, 2 ,
2697+                 ACL_FORMAT_ND, scale_ne_offset);
25942698            aclTensor* acl_output_tensor = ggml_cann_create_tensor (
2595-                 (char *)output_buffer + batch1 * output_stride, ACL_FLOAT16,
2596-                 output_elem_size, output_ne, output_nb, 2 );
2699+                 (char *)output_buffer + batch1 * output_stride,
2700+                 ACL_FLOAT16,
2701+                 output_elem_size, output_ne, output_nb, 2 ,
2702+                 ACL_FORMAT_ND, output_ne_offset);
25972703
25982704            ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
2599-                 acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr ,
2600-                 nullptr , nullptr , nullptr , QK8_0, acl_output_tensor,
2601-                 &workspaceSize, &executor));
2602- 
2603-             if  (workspaceSize > 0  && workspaceAddr == nullptr ) {
2604-                 ggml_cann_pool_alloc workspace_allocator (ctx.pool (),
2605-                                                          workspaceSize);
2606-                 workspaceAddr = workspace_allocator.get ();
2705+                 acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2706+                 nullptr , nullptr , nullptr , nullptr , QK8_0,
2707+                 acl_output_tensor, &workspaceSize, &executor));
2708+             if  (workspaceAddr == nullptr ) {
2709+                 workspaceAddr = workspace_allocator.alloc (workspaceSize);
26072710            }
2608- 
26092711            ACL_CHECK (aclnnWeightQuantBatchMatmulV2 (
26102712                workspaceAddr, workspaceSize, executor, ctx.stream ()));
26112713
2612-             ACL_CHECK (aclDestroyTensor (acl_input_tensor));
26132714            ACL_CHECK (aclDestroyTensor (acl_weight_tensor));
26142715            ACL_CHECK (aclDestroyTensor (acl_scale_tensor));
26152716            ACL_CHECK (aclDestroyTensor (acl_output_tensor));
2717+ 
2718+             //  other splits
2719+             for  (int64_t  split = 1 ; split < split_size; split++) {
2720+                 weight_ne_offset += weight_elem_size * weight_ne[0 ] * weight_ne[1 ];
2721+                 weight_ne[0 ] = max_elem_size * (split + 1 ) > src0->ne [1 ] ? src0->ne [1 ] - (max_elem_size * split) : max_elem_size;
2722+                 scale_ne_offset += scale_elem_size * scale_ne[0 ] * scale_ne[1 ];
2723+                 scale_ne[0 ] = weight_ne[0 ];
2724+                 output_ne_offset += output_elem_size * output_ne[0 ] * output_ne[1 ];
2725+                 output_ne[0 ] = weight_ne[0 ];
2726+ 
2727+                 acl_weight_tensor = ggml_cann_create_tensor (
2728+                     (char *)src0->data  + batch0 * weight_stride,
2729+                     ggml_cann_type_mapping (type),
2730+                     weight_elem_size, weight_ne, weight_nb, 2 ,
2731+                     ACL_FORMAT_ND, weight_ne_offset);
2732+                 acl_scale_tensor = ggml_cann_create_tensor (
2733+                     scale_offset + batch0 * scale_stride,
2734+                     ACL_FLOAT16,
2735+                     scale_elem_size, scale_ne, scale_nb, 2 ,
2736+                     ACL_FORMAT_ND, scale_ne_offset);
2737+                 acl_output_tensor = ggml_cann_create_tensor (
2738+                     (char *)output_buffer + batch1 * output_stride,
2739+                     ACL_FLOAT16,
2740+                     output_elem_size, output_ne, output_nb, 2 ,
2741+                     ACL_FORMAT_ND, output_ne_offset);
2742+ 
2743+                 ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
2744+                     acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2745+                     nullptr , nullptr , nullptr , nullptr , QK8_0,
2746+                     acl_output_tensor, &workspaceSize, &executor));
2747+                 ACL_CHECK (aclnnWeightQuantBatchMatmulV2 (
2748+                     workspaceAddr, workspaceSize, executor, ctx.stream ()));
2749+ 
2750+                 ACL_CHECK (aclDestroyTensor (acl_weight_tensor));
2751+                 ACL_CHECK (aclDestroyTensor (acl_scale_tensor));
2752+                 ACL_CHECK (aclDestroyTensor (acl_output_tensor));
2753+             }
2754+ 
2755+             ACL_CHECK (aclDestroyTensor (acl_input_tensor));
26162756        }
26172757    }
26182758
26192759    //  cast out
2620-     int64_t * output_cast_ne = dst->ne ;
2621-     size_t  output_cast_nb[GGML_MAX_DIMS];
2622-     output_cast_nb[0 ] = sizeof (uint16_t );
2623-     for  (int  i = 1 ; i < GGML_MAX_DIMS; i++) {
2624-         output_cast_nb[i] = output_cast_nb[i - 1 ] * output_cast_ne[i - 1 ];
2625-     }
2760+     if  (dst->type  != GGML_TYPE_F16) {
2761+         int64_t * output_cast_ne = dst->ne ;
2762+         size_t  output_cast_nb[GGML_MAX_DIMS];
2763+         output_cast_nb[0 ] = sizeof (uint16_t );
2764+         for  (int  i = 1 ; i < GGML_MAX_DIMS; i++) {
2765+             output_cast_nb[i] = output_cast_nb[i - 1 ] * output_cast_ne[i - 1 ];
2766+         }
26262767
2627-     aclTensor* acl_output_tensor =
2628-         ggml_cann_create_tensor (output_buffer, ACL_FLOAT16, output_elem_size,
2629-                                 output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
2630-     aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
2631-     aclnn_cast (ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
2768+         aclTensor* acl_output_tensor = ggml_cann_create_tensor (
2769+             output_buffer,
2770+             ACL_FLOAT16,
2771+             output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
2772+         aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
2773+         aclnn_cast (ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping (dst->type ));
26322774
2633-     ACL_CHECK (aclDestroyTensor (acl_output_tensor));
2634-     ACL_CHECK (aclDestroyTensor (acl_dst_tensor));
2775+         ACL_CHECK (aclDestroyTensor (acl_output_tensor));
2776+         ACL_CHECK (aclDestroyTensor (acl_dst_tensor));
2777+     }
26352778}
26362779
26372780void  ggml_cann_mul_mat (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
0 commit comments