@@ -2790,10 +2790,15 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
27902790 (char *)output_buffer + batch1 * output_stride, ACL_FLOAT16,
27912791 output_elem_size, output_ne, output_nb, 2 , ACL_FORMAT_ND,
27922792 output_ne_offset);
2793+
2794+ int64_t antiquantGroupSize = 0 ;
2795+ if (src0->ne [0 ] > QK8_0) {
2796+ antiquantGroupSize = QK8_0;
2797+ }
27932798
27942799 ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
27952800 acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr ,
2796- nullptr , nullptr , nullptr , QK8_0 , acl_output_tensor,
2801+ nullptr , nullptr , nullptr , antiquantGroupSize , acl_output_tensor,
27972802 &workspaceSize, &executor));
27982803 if (workspaceAddr == nullptr ) {
27992804 workspaceAddr = workspace_allocator.alloc (workspaceSize);
@@ -2833,7 +2838,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
28332838
28342839 ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
28352840 acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2836- nullptr , nullptr , nullptr , nullptr , QK8_0 ,
2841+ nullptr , nullptr , nullptr , nullptr , antiquantGroupSize ,
28372842 acl_output_tensor, &workspaceSize, &executor));
28382843 ACL_CHECK (aclnnWeightQuantBatchMatmulV2 (
28392844 workspaceAddr, workspaceSize, executor, ctx.stream ()));
0 commit comments