Skip to content

Commit 716d318

Browse files
ikawrakow and Iwan Kawrakow authored
Fix fused up+gate when mmq is not supported (#952)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 459bf58 commit 716d318

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3010,16 +3010,25 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
30103010
0, src0_2->ne[1], src1->ne[1], ne10_padded, stream);
30113011
CUDA_CHECK(cudaGetLastError());
30123012
} else {
3013-
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1, ne10_padded, src0_1->type, stream);
3014-
CUDA_CHECK(cudaGetLastError());
30153013

3016-
ggml_cuda_op_mul_mat_q(ctx, src0_1, src1, dst, (const char *)src0_1->data, nullptr, src1_quantized.get(), dst_up.get(),
3017-
0, src0_1->ne[1], src1->ne[1], ne10_padded, stream);
3018-
CUDA_CHECK(cudaGetLastError());
3014+
if (ggml_cuda_should_use_mmq(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[1])) {
3015+
quantize_mmq_q8_1_cuda((const float *)src1->data, src1_quantized.get(), src1->ne[0], src1->ne[1], 1,
3016+
ne10_padded, src0_1->type, stream);
3017+
CUDA_CHECK(cudaGetLastError());
30193018

3020-
ggml_cuda_op_mul_mat_q(ctx, src0_2, src1, dst, (const char *)src0_2->data, nullptr, src1_quantized.get(), (float *)dst->data,
3021-
0, src0_1->ne[1], src1->ne[1], ne10_padded, stream);
3022-
CUDA_CHECK(cudaGetLastError());
3019+
ggml_cuda_op_mul_mat_q(ctx, src0_1, src1, dst, (const char *)src0_1->data, nullptr, src1_quantized.get(), dst_up.get(),
3020+
0, src0_1->ne[1], src1->ne[1], ne10_padded, stream);
3021+
CUDA_CHECK(cudaGetLastError());
3022+
3023+
ggml_cuda_op_mul_mat_q(ctx, src0_2, src1, dst, (const char *)src0_2->data, nullptr, src1_quantized.get(), (float *)dst->data,
3024+
0, src0_1->ne[1], src1->ne[1], ne10_padded, stream);
3025+
CUDA_CHECK(cudaGetLastError());
3026+
} else {
3027+
auto local_dst = *dst;
3028+
local_dst.data = dst_up.get();
3029+
ggml_cuda_mul_mat(ctx, src0_1, src1, &local_dst, nullptr, 0);
3030+
ggml_cuda_mul_mat(ctx, src0_2, src1, dst, nullptr, 0);
3031+
}
30233032
}
30243033

30253034
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),

0 commit comments

Comments
 (0)