|
39 | 39 | #include "ggml-sycl/backend.hpp" |
40 | 40 | #include "ggml-sycl/presets.hpp" |
41 | 41 | #include "ggml-sycl/gemm.hpp" |
| 42 | +#include "ggml.h" |
42 | 43 |
|
43 | 44 | static bool g_sycl_loaded = false; |
44 | 45 |
|
@@ -3446,22 +3447,38 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor |
3446 | 3447 | if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda) |
3447 | 3448 | use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; |
3448 | 3449 |
|
| 3450 | + // printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); |
| 3451 | + // printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); |
| 3452 | + // printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); |
| 3453 | + // printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); |
| 3454 | + // printf("src0 is contiguous %d, transposed %d, permuted = %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_is_permuted(src0), ggml_type_name(src0->type), src0->name); |
| 3455 | + // printf("src1 is contiguous %d, transposed %d, permuted = %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_is_permuted(src1), ggml_type_name(src1->type), src1->name); |
| 3456 | + |
| 3457 | + |
| 3458 | + |
3449 | 3459 | if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { |
3450 | 3460 | // KQ single-batch |
| 3461 | + // printf("MUL_MAT KQ single-batch\n"); |
3451 | 3462 | ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst); |
3452 | 3463 | } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { |
3453 | 3464 | // KQV single-batch |
| 3465 | + // printf("MUL_MAT KQV single-batch\n"); |
3454 | 3466 | ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); |
3455 | 3467 | } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { |
3456 | 3468 | // KQ + KQV multi-batch |
| 3469 | + // printf("MUL_MAT KQ + KQV multi-batch\n"); |
3457 | 3470 | ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); |
3458 | 3471 | } else if (use_dequantize_mul_mat_vec) { |
| 3472 | + // printf("MUL_MAT dmmv\n"); |
3459 | 3473 | ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); |
3460 | 3474 | } else if (use_mul_mat_vec_q) { |
| 3475 | + // printf("MUL_MAT mmvq\n"); |
3461 | 3476 | ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); |
3462 | 3477 | } else if (use_mul_mat_q) { |
| 3478 | + // printf("MUL_MAT mmq\n"); |
3463 | 3479 | ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true); |
3464 | 3480 | } else { |
| 3481 | + // printf("MUL_MAT ELSE\n"); |
3465 | 3482 | ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); |
3466 | 3483 | } |
3467 | 3484 | } |
@@ -4350,9 +4367,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g |
4350 | 4367 | if (op->op == GGML_OP_MUL_MAT) { |
4351 | 4368 | a = op->src[0]; |
4352 | 4369 | b = op->src[1]; |
4353 | | - if (ggml_is_permuted(a) || ggml_is_permuted(b)) { |
| 4370 | + if (ggml_is_permuted(a)) { |
4354 | 4371 | // TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021 |
4355 | | - return false; |
| 4372 | + if (a->nb[0] <= a->nb[1] && a->nb[3] <= a->nb[2]) return false; // 0,1,3,2 Unsupported |
| 4373 | + if (b->type != GGML_TYPE_F32) return false; |
4356 | 4374 | } |
4357 | 4375 | } else { |
4358 | 4376 | a = op->src[2]; |
|
0 commit comments