@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
27902790 (char *)output_buffer + batch1 * output_stride, ACL_FLOAT16,
27912791 output_elem_size, output_ne, output_nb, 2 , ACL_FORMAT_ND,
27922792 output_ne_offset);
2793+ int64_t antiquantGroupSize = 0 ;
2794+ if (src0->ne [0 ] > QK8_0) {
2795+ antiquantGroupSize = QK8_0;
2796+ }
27932797
27942798 ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
27952799 acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr ,
2796- nullptr , nullptr , nullptr , QK8_0 , acl_output_tensor,
2800+ nullptr , nullptr , nullptr , antiquantGroupSize , acl_output_tensor,
27972801 &workspaceSize, &executor));
27982802 if (workspaceAddr == nullptr ) {
27992803 workspaceAddr = workspace_allocator.alloc (workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
28332837
28342838 ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
28352839 acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2836- nullptr , nullptr , nullptr , nullptr , QK8_0 ,
2840+ nullptr , nullptr , nullptr , nullptr , antiquantGroupSize ,
28372841 acl_output_tensor, &workspaceSize, &executor));
28382842 ACL_CHECK (aclnnWeightQuantBatchMatmulV2 (
28392843 workspaceAddr, workspaceSize, executor, ctx.stream ()));
0 commit comments