Skip to content

Commit 4605828

Browse files
committed
fix use_mul_mat_vec_f for mul_mat_id
1 parent 9a9cd33 commit 4605828

File tree

2 files changed: +4 -2 lines changed

2 files changed: +4 -2 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2063,12 +2063,14 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
2063 2063       ggml_tensor * src1 = tensor->src[1];
2064 2064       const ggml_tensor * dst = tensor;
2065 2065
     2066 +     const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
     2067 +
2066 2068       bool use_mul_mat_vec_f =
2067 2069           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
2068 2070           src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
2069 2071
2070 2072       const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
2071      -     use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]);
     2073 +     use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2]: src1->ne[1]);
2072 2074
2073 2075       if (tensor->op == GGML_OP_MUL_MAT_ID) {
2074 2076           use_mul_mat_vec_f = use_mul_mat_vec_f && dst->ne[2] == 1;

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4690,7 +4690,7 @@ struct test_fused_ffn_gate : public test_case {
4690 4690       }
4691 4691
4692 4692       double max_nmse_err() override {
4693      -         return 5e-4;
     4693 +         return 1e-3;
4694 4694       }
4695 4695   };
4696 4696

0 commit comments

Comments
 (0)