@@ -2790,10 +2790,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
27902790                (char *)output_buffer + batch1 * output_stride, ACL_FLOAT16,
27912791                output_elem_size, output_ne, output_nb, 2 , ACL_FORMAT_ND,
27922792                output_ne_offset);
2793+             int64_t  antiquantGroupSize = 0 ;
2794+             if  (src0->ne [0 ] > QK8_0) {
2795+                 antiquantGroupSize = QK8_0;
2796+             }
27932797
27942798            ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
27952799                acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr ,
2796-                 nullptr , nullptr , nullptr , QK8_0 , acl_output_tensor,
2800+                 nullptr , nullptr , nullptr , antiquantGroupSize , acl_output_tensor,
27972801                &workspaceSize, &executor));
27982802            if  (workspaceAddr == nullptr ) {
27992803                workspaceAddr = workspace_allocator.alloc (workspaceSize);
@@ -2833,7 +2837,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
28332837
28342838                ACL_CHECK (aclnnWeightQuantBatchMatmulV2GetWorkspaceSize (
28352839                    acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
2836-                     nullptr , nullptr , nullptr , nullptr , QK8_0 ,
2840+                     nullptr , nullptr , nullptr , nullptr , antiquantGroupSize ,
28372841                    acl_output_tensor, &workspaceSize, &executor));
28382842                ACL_CHECK (aclnnWeightQuantBatchMatmulV2 (
28392843                    workspaceAddr, workspaceSize, executor, ctx.stream ()));
0 commit comments