Skip to content

Commit c3f6678

Browse files
committed
sycl : temporary fix for performance regression
1 parent 9901068 commit c3f6678

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
#include "ggml-sycl/backend.hpp"
4040
#include "ggml-sycl/presets.hpp"
4141
#include "ggml-sycl/gemm.hpp"
42+
#include "ggml.h"
4243

4344
static bool g_sycl_loaded = false;
4445

@@ -3446,22 +3447,38 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
34463447
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
34473448
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
34483449

3450+
// printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
3451+
// printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
3452+
// printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
3453+
// printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
3454+
// printf("src0 is contiguous %d, transposed %d, permuted = %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_is_permuted(src0), ggml_type_name(src0->type), src0->name);
3455+
// printf("src1 is contiguous %d, transposed %d, permuted = %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_is_permuted(src1), ggml_type_name(src1->type), src1->name);
3456+
3457+
3458+
34493459
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
34503460
// KQ single-batch
3461+
// printf("MUL_MAT KQ single-batch\n");
34513462
ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
34523463
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
34533464
// KQV single-batch
3465+
// printf("MUL_MAT KQV single-batch\n");
34543466
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
34553467
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
34563468
// KQ + KQV multi-batch
3469+
// printf("MUL_MAT KQ + KQV multi-batch\n");
34573470
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
34583471
} else if (use_dequantize_mul_mat_vec) {
3472+
// printf("MUL_MAT dmmv\n");
34593473
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
34603474
} else if (use_mul_mat_vec_q) {
3475+
// printf("MUL_MAT mmvq\n");
34613476
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
34623477
} else if (use_mul_mat_q) {
3478+
// printf("MUL_MAT mmq\n");
34633479
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
34643480
} else {
3481+
// printf("MUL_MAT ELSE\n");
34653482
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
34663483
}
34673484
}
@@ -4350,9 +4367,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
43504367
if (op->op == GGML_OP_MUL_MAT) {
43514368
a = op->src[0];
43524369
b = op->src[1];
4353-
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
4370+
if (ggml_is_permuted(a)) {
43544371
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
4355-
return false;
4372+
if (a->nb[0] <= a->nb[1] && a->nb[3] <= a->nb[2]) return false; // 0,1,3,2 Unsupported
4373+
if (b->type != GGML_TYPE_F32) return false;
43564374
}
43574375
} else {
43584376
a = op->src[2];

0 commit comments

Comments
 (0)