 #include <aclnnop/aclnn_group_norm.h>
 #include <aclnnop/aclnn_index_fill_tensor.h>
 #include <aclnnop/aclnn_layer_norm.h>
+#include <aclnnop/aclnn_mm.h>
+#include <aclnnop/aclnn_batch_matmul.h>
 #include <aclnnop/aclnn_matmul.h>
 #include <aclnnop/aclnn_max_pool.h>
 #include <aclnnop/aclnn_permute.h>
@@ -2423,7 +2425,6 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
                           aclTensor* acl_weight, aclTensor* acl_dst) {
     int8_t cube_math_type = 1;  // ALLOW_FP32_DOWN_PRECISION, when input is
                                 // fp32, atlas a2 will transpose it to HFLOAT32.
-
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
@@ -2441,6 +2442,80 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
         aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
 }
 
+/**
+ * @brief Performs matrix multiplication of two 2D tensors.
+ *
+ * This function computes the matrix multiplication of the input tensor
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
+ * destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text{acl_dst}=\text{acl_input@acl_weight}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_input The input tensor for the matrix multiplication.
+ * @param acl_weight The weight tensor for the matrix multiplication.
+ * @param acl_dst The destination tensor where the result of the matrix
+ *        multiplication will be stored.
+ */
+static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+                             aclTensor* acl_weight, aclTensor* acl_dst) {
+    int8_t cube_math_type = 2;
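+    // cube_math_type = 2 appears to correspond to USE_FP16 in the aclnn
+    // cubeMathType enum (1 above is ALLOW_FP32_DOWN_PRECISION); this is an
+    // assumption, check the CANN documentation for your toolkit version.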
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst,
+                                      cube_math_type, &workspaceSize,
+                                      &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs matrix multiplication of two 3D tensors.
+ *
+ * This function computes the matrix multiplication of the input tensor
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
+ * destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text{acl_dst}=\text{acl_input@acl_weight}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_input The input tensor for the matrix multiplication.
+ * @param acl_weight The weight tensor for the matrix multiplication.
+ * @param acl_dst The destination tensor where the result of the matrix
+ *        multiplication will be stored.
+ */
+static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+                             aclTensor* acl_weight, aclTensor* acl_dst) {
+    int8_t cube_math_type = 2;
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
+                                               cube_math_type, &workspaceSize,
+                                               &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
 /**
  * @brief Performs matrix multiplication with floating-point precision on
  *        tensors using the CANN backend.
@@ -2462,20 +2537,43 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     // broadcast: when weight ne2 or ne3 is not 1, the weight needs to be repeated.
     BCAST_MUL_MAT_SHAPE(input, weight, dst);
 
-    // transpose weight: [1,2,3,4] -> [1,2,4,3]
-    int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
-                              bcast_weight_ne[2], bcast_weight_ne[3],
-                              bcast_weight_ne[4], bcast_weight_ne[5]};
-    size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
-                             bcast_weight_nb[2], bcast_weight_nb[3],
-                             bcast_weight_nb[4], bcast_weight_nb[5]};
+    int64_t n_dims = bcast_dims;
+    if (bcast_input_ne[3] == bcast_weight_ne[3] && bcast_input_ne[3] == 1) {
+        if (bcast_input_ne[2] == 1 && bcast_weight_ne[2] == 1) {
+            n_dims = 2;
+        } else if (bcast_input_ne[2] == 1) {
+            n_dims = 3;
+        }
+    }
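+    // Collapse the broadcast shapes where possible: if both ne[3] and both
+    // ne[2] are 1 this is a plain 2D matmul; if both ne[3] are 1 and only the
+    // input's ne[2] is 1 it can run as a 3D batched matmul. The switch below
+    // dispatches these cases to the aclnnMm / aclnnBatchMatMul helpers.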
 
-    aclTensor* acl_weight_tensor =
-        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
     aclTensor* acl_input_tensor =
-        ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
-    aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+        ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims);
+    int64_t transpose_ne[] = {
+        bcast_weight_ne[1], bcast_weight_ne[0],
+        bcast_weight_ne[2], bcast_weight_ne[3],
+        bcast_weight_ne[4], bcast_weight_ne[5]
+    };
+    size_t transpose_nb[] = {
+        bcast_weight_nb[1], bcast_weight_nb[0],
+        bcast_weight_nb[2], bcast_weight_nb[3],
+        bcast_weight_nb[4], bcast_weight_nb[5]
+    };
+    aclTensor* acl_weight_tensor =
+        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims);
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);
+
+    switch (n_dims) {
+        case 2:
+            aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+            break;
+        case 3:
+            aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+            break;
+        default:
+            aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+            break;
+    }
 
     ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
     ACL_CHECK(aclDestroyTensor(acl_input_tensor));
@@ -2501,46 +2599,40 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
     ggml_tensor* src0 = dst->src[0];  // weight
     ggml_tensor* src1 = dst->src[1];  // input
 
-    // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
-    // is regarded as batch. weight need transpose.
-    int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
+    // The shape of the weight is NCHW.
+    // Matrix multiplication uses the HW dims.
+    // NC is regarded as batch.
+    // The weight needs to be transposed.
     float weight_elem_size;
     if (type == GGML_TYPE_Q4_0) {
         weight_elem_size = float(sizeof(uint8_t)) / 2;
-    }
-    else if (type == GGML_TYPE_Q8_0) {
+    } else if (type == GGML_TYPE_Q8_0) {
         weight_elem_size = float(sizeof(uint8_t));
-    }
-    else {
+    } else {
         GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
     }
-    float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
-
-    // size of one matrix is element_size * height * width.
-    size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
+    float weight_nb[] = {src0->ne[0] * weight_elem_size, weight_elem_size};
+    size_t weight_stride = src0->ne[1] * src0->ne[0] * weight_elem_size;
     size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
 
     // The scale is stored at the end of the weight and also needs to be transposed.
-    GGML_ASSERT(QK4_0 == QK8_0);
-    int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
     size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
-                         scale_elem_size};
-    size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
+    size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size, scale_elem_size};
+    size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
     char* scale_offset = (char*)src0->data + weight_size;
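+    // Illustrative layout (hypothetical numbers, not part of this change): for
+    // a Q4_0 weight with ne[0] = 4096 and ne[1] = 4096, weight_stride is
+    // 4096 * 4096 * 0.5 B = 8 MiB per matrix, and each row carries
+    // 4096 / QK8_0 = 128 fp16 scales, giving scale_stride = 1 MiB per matrix.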
 
     // input
-    void* input_buffer;
     size_t input_elem_size = sizeof(uint16_t);
     int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
-    size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
-    size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
-
+    size_t input_nb[] = {input_elem_size, input_ne[0] * input_elem_size};
+    size_t input_stride = input_ne[0] * input_ne[1] * input_elem_size;
     ggml_cann_pool_alloc input_alloctor(ctx.pool());
+    void* input_buffer = src1->data;
+
+    // cast the input to fp16 when needed
     if (src1->type != GGML_TYPE_F16) {
         aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
-        input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
-        input_buffer = input_alloctor.get();
+        input_buffer = input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
 
         int64_t* input_cast_ne = src1->ne;
         size_t input_cast_nb[GGML_MAX_DIMS];
@@ -2550,88 +2642,139 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
         }
 
         aclTensor* acl_input_tensor = ggml_cann_create_tensor(
-            input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
-            input_cast_nb, GGML_MAX_DIMS);
+            input_buffer,
+            ACL_FLOAT16,
+            input_elem_size, input_cast_ne, input_cast_nb, GGML_MAX_DIMS);
         aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
+
         ACL_CHECK(aclDestroyTensor(acl_input_tensor));
         ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
-    } else {
-        input_buffer = src1->data;
     }
 
     // output
     size_t output_elem_size = sizeof(uint16_t);
-    int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
-    size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
-    ggml_cann_pool_alloc output_alloctor(
-        ctx.pool(), ggml_nelements(dst) * output_elem_size);
-    void* output_buffer = output_alloctor.get();
-    size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];
+    size_t output_nb[] = {output_elem_size, dst->ne[0] * output_elem_size};
+    ggml_cann_pool_alloc output_allocator(ctx.pool());
+    void* output_buffer = output_allocator.alloc(ggml_nelements(dst) * output_elem_size);
+    size_t output_stride = dst->ne[0] * dst->ne[1] * output_elem_size;
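+    // The result is accumulated in fp16 here and cast back to dst->type at the
+    // end of the function when dst is not already F16.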
 
     // aclnn
+    int64_t max_elem_size = 65535;
+    int64_t split_size = (src0->ne[1] / max_elem_size) + 1;
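+    // Process src0->ne[1] in chunks of at most 65535 rows, presumably because
+    // aclnnWeightQuantBatchMatmulV2 limits a single dimension to 65535.
+    // Hypothetical example (not from this change): ne[1] = 100000 yields
+    // split_size = 2, i.e. one chunk of 65535 rows and one of 34465 rows.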
+    ggml_cann_pool_alloc workspace_allocator(ctx.pool());
+    aclOpExecutor* executor = nullptr;
     uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
-
     for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
         for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
             int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
             int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);
 
-            int64_t batch1 = n1 * src1->ne[2] + c1;
-            int64_t batch0 = n0 * src0->ne[2] + c0;
+            int64_t batch1 = (n1 * src1->ne[2]) + c1;
+            int64_t batch0 = (n0 * src0->ne[2]) + c0;
 
             aclTensor* acl_input_tensor = ggml_cann_create_tensor(
                 (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
                 input_elem_size, input_ne, input_nb, 2);
+
+            // first split
+            int64_t weight_ne_offset = 0;
+            int64_t weight_ne[2] = {max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size, src0->ne[0]};
+            int64_t scale_ne_offset = 0;
+            int64_t scale_ne[2] = {weight_ne[0], weight_ne[1] / QK8_0};
+            int64_t output_ne_offset = 0;
+            int64_t output_ne[2] = {weight_ne[0], dst->ne[1]};
+
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                 (char*)src0->data + batch0 * weight_stride,
-                ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
-                weight_nb, 2);
+                ggml_cann_type_mapping(type),
+                weight_elem_size, weight_ne, weight_nb, 2,
+                ACL_FORMAT_ND, weight_ne_offset);
             aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
-                scale_offset + batch0 * scale_stride, ACL_FLOAT16,
-                scale_elem_size, scale_ne, scale_nb, 2);
+                scale_offset + batch0 * scale_stride,
+                ACL_FLOAT16,
+                scale_elem_size, scale_ne, scale_nb, 2,
+                ACL_FORMAT_ND, scale_ne_offset);
             aclTensor* acl_output_tensor = ggml_cann_create_tensor(
-                (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
-                output_elem_size, output_ne, output_nb, 2);
+                (char*)output_buffer + batch1 * output_stride,
+                ACL_FLOAT16,
+                output_elem_size, output_ne, output_nb, 2,
+                ACL_FORMAT_ND, output_ne_offset);
 
             ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
-                acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
-                nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
-                &workspaceSize, &executor));
-
-            if (workspaceSize > 0 && workspaceAddr == nullptr) {
-                ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
-                                                         workspaceSize);
-                workspaceAddr = workspace_allocator.get();
+                acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
+                nullptr, nullptr, nullptr, nullptr, QK8_0,
+                acl_output_tensor, &workspaceSize, &executor));
+            if (workspaceAddr == nullptr) {
+                workspaceAddr = workspace_allocator.alloc(workspaceSize);
             }
-
             ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
                 workspaceAddr, workspaceSize, executor, ctx.stream()));
 
-            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
             ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
             ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
             ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+
+            // other splits
+            for (int64_t split = 1; split < split_size; split++) {
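+                // Advance each offset past the chunk handled by the previous
+                // iteration; the final chunk only covers the rows of
+                // src0->ne[1] that remain.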
+                weight_ne_offset += weight_elem_size * weight_ne[0] * weight_ne[1];
+                weight_ne[0] = max_elem_size * (split + 1) > src0->ne[1] ? src0->ne[1] - (max_elem_size * split) : max_elem_size;
+                scale_ne_offset += scale_elem_size * scale_ne[0] * scale_ne[1];
+                scale_ne[0] = weight_ne[0];
+                output_ne_offset += output_elem_size * output_ne[0] * output_ne[1];
+                output_ne[0] = weight_ne[0];
+
+                acl_weight_tensor = ggml_cann_create_tensor(
+                    (char*)src0->data + batch0 * weight_stride,
+                    ggml_cann_type_mapping(type),
+                    weight_elem_size, weight_ne, weight_nb, 2,
+                    ACL_FORMAT_ND, weight_ne_offset);
+                acl_scale_tensor = ggml_cann_create_tensor(
+                    scale_offset + batch0 * scale_stride,
+                    ACL_FLOAT16,
+                    scale_elem_size, scale_ne, scale_nb, 2,
+                    ACL_FORMAT_ND, scale_ne_offset);
+                acl_output_tensor = ggml_cann_create_tensor(
+                    (char*)output_buffer + batch1 * output_stride,
+                    ACL_FLOAT16,
+                    output_elem_size, output_ne, output_nb, 2,
+                    ACL_FORMAT_ND, output_ne_offset);
+
+                ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
+                    acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
+                    nullptr, nullptr, nullptr, nullptr, QK8_0,
+                    acl_output_tensor, &workspaceSize, &executor));
+                ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
+                    workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+                ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
+                ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
+                ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+            }
+
+            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
         }
     }
 
     // cast out
-    int64_t* output_cast_ne = dst->ne;
-    size_t output_cast_nb[GGML_MAX_DIMS];
-    output_cast_nb[0] = sizeof(uint16_t);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
-    }
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t* output_cast_ne = dst->ne;
+        size_t output_cast_nb[GGML_MAX_DIMS];
+        output_cast_nb[0] = sizeof(uint16_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
+        }
 
-    aclTensor* acl_output_tensor =
-        ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
-                                output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
-    aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-    aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
+        aclTensor* acl_output_tensor = ggml_cann_create_tensor(
+            output_buffer,
+            ACL_FLOAT16,
+            output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
+        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+        aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
 
-    ACL_CHECK(aclDestroyTensor(acl_output_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+    }
 }
 
 void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {