@@ -2705,9 +2705,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
27052705 " : converting src1 to fp16" );
27062706
27072707 // iterate tensor dims and find the slowest moving dim and stride
2708- int64_t last_dim=0 ;
2709- int64_t last_str=0 ;
2710- int64_t largest_str=0 ;
2708+ int last_dim=0 ;
2709+ int last_str=0 ;
2710+ size_t largest_str=0 ;
27112711 for (int i = 0 ; i< 4 ; i++){
27122712 // last stride is always the largest
27132713 if (src1->nb [i] == largest_str){
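The hunk above only retypes the scan variables; for context, here is a minimal standalone sketch of the stride scan they feed into, written over a plain nb[4] array of byte strides rather than the real ggml_tensor fields (illustrative only, not part of the patch):

```cpp
// Illustrative sketch (not part of the patch): find the slowest-moving
// dimension of a 4-D tensor, i.e. the dimension with the largest byte stride.
#include <cstddef>
#include <cstdio>

static int slowest_moving_dim(const size_t nb[4], size_t & largest_str) {
    int last_dim = 0;
    largest_str  = 0;
    for (int i = 0; i < 4; i++) {
        // for a dense layout the outermost stride is the largest
        if (nb[i] >= largest_str) {
            largest_str = nb[i];
            last_dim    = i;
        }
    }
    return last_dim;
}

int main() {
    // byte strides of a contiguous 8x4x2x3 fp16 tensor
    const size_t nb[4] = { 2, 16, 64, 128 };
    size_t largest = 0;
    printf("slowest dim = %d, stride = %zu bytes\n", slowest_moving_dim(nb, largest), largest);
    return 0;
}
```

Switching largest_str to size_t matches the type of ggml_tensor::nb, so the comparisons in the scan stay within one unsigned type.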
@@ -2783,7 +2783,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
                                                   const sycl::half *src1, float *dst,
                                                   int64_t a0, int64_t a1, int64_t batcha,
-                                                  int64_t b0, int64_t b1, int64_t batchb,
+                                                  int64_t /* b0 */, int64_t b1, int64_t batchb,
                                                   int64_t sa0, int64_t sa1, int64_t sa2,
                                                   int64_t sb0, int64_t sb1, int64_t sb2,
                                                   int64_t sd2) {
@@ -2832,14 +2832,26 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         }
     };

-    bool cont_batches_a = nb02 * ne02 == nb03;
-    bool cont_batches_b = nb12 * ne12 == nb13;
-    if (cont_batches_a && cont_batches_b) {
+    const bool cont_batches_dim2_a = nb02 * ne02 == nb03;
+    const bool cont_batches_dim2_b = nb12 * ne12 == nb13;
+    const bool cont_batches_dim3_a = ne02 == 1 && nb02 * ne01 == nb03;
+    const bool cont_batches_dim3_b = ne12 == 1 && nb12 * ne11 == nb13;
+    if (cont_batches_dim2_a && cont_batches_dim2_b) {
+        // A batch is considered contiguous if dimension 2 is not strided
         int64_t batches0 = ne02 * ne03;
         int64_t batches1 = ne12 * ne13;
         launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
                                 ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
                                 str_b2, nb2 / sizeof(float));
+    } else if (cont_batches_dim3_a && cont_batches_dim3_b) {
+        // Same as the case above, except that dimension 2 has size 1 and the batch runs along dimension 3.
+        int64_t batches0 = ne02 * ne03;
+        int64_t batches1 = ne12 * ne13;
+        int64_t str_a3 = nb03 / type_size_src0;
+        int64_t str_b3 = nb13 / type_size_src1;
+        launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                                ne10, ne11, batches1, str_a0, str_a1, str_a3, str_b0, str_b1,
+                                str_b3, nb2 / sizeof(float));
     } else {
         for (int64_t b_a = 0; b_a < ne03; b_a++) {
             const sycl::half *src0_f16_shifted
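For reference, the two layouts distinguished by the new cont_batches_dim2_*/cont_batches_dim3_* flags can be written against plain ne[]/nb[] arrays (nb02 -> nb[2], ne01 -> ne[1], and so on); this is a sketch only, not part of the diff:

```cpp
// Illustrative sketch (not part of the patch): the two layouts that allow a
// single batched GEMM call, expressed over a tensor's sizes ne[] and byte
// strides nb[].
#include <cstdint>
#include <cstddef>

// dims 2 and 3 are packed back to back, so ne[2]*ne[3] slices form one batch
static bool cont_batches_dim2(const int64_t ne[4], const size_t nb[4]) {
    return nb[2] * (size_t) ne[2] == nb[3];
}

// dim 2 is a singleton and the batch runs along dim 3, packed right after the rows
static bool cont_batches_dim3(const int64_t ne[4], const size_t nb[4]) {
    return ne[2] == 1 && nb[2] * (size_t) ne[1] == nb[3];
}

int main() {
    // contiguous 8x4x2x3 fp16 tensor: dims 2 and 3 are packed, dim 2 is not a singleton
    const int64_t ne[4] = { 8, 4, 2, 3 };
    const size_t  nb[4] = { 2, 16, 64, 128 };
    return cont_batches_dim2(ne, nb) && !cont_batches_dim3(ne, nb) ? 0 : 1;
}
```

In the first case the batch stride handed to the GEMM is the dim-2 stride (str_a2/str_b2); in the second it is the dim-3 stride, which is why the new branch derives str_a3/str_b3 from nb03/nb13 before calling launch_gemm_for_batches.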
@@ -4215,6 +4227,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                     return false;
                 }
+                // TODO: The configuration below needs more work to be supported with oneDNN
+                if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1) {
+                    return false;
+                }
+                // TODO: This specific configuration can fail with oneDNN and needs more debugging
+                if (!ggml_is_permuted(a) && ggml_is_permuted(b) && b->ne[2] > 1 && b->ne[3] > 1 &&
+                    a->ne[0] > 128 && a->ne[2] == 1 && src0_type == GGML_TYPE_F16) {
+                    return false;
+                }
                 return true;
             }
         case GGML_OP_OUT_PROD:
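The two new guards can be read as predicates over the operand shapes; the sketch below restates them against a hypothetical tensor_view struct whose fields stand in for ggml_is_permuted, ggml_is_contiguous and the F16 type check (not the ggml API, not part of the patch):

```cpp
// Illustrative sketch (not part of the patch): the two operand configurations
// the backend now declines for MUL_MAT when oneDNN is enabled.
#include <cstdint>

struct tensor_view {
    int64_t ne[4];      // size of each dimension
    bool    permuted;   // stands in for ggml_is_permuted(t)
    bool    contiguous; // stands in for ggml_is_contiguous(t)
    bool    is_f16;     // stands in for the GGML_TYPE_F16 check
};

// batched (ne[2] > 1 and ne[3] > 1), permuted, non-contiguous A is not supported yet
static bool unsupported_permuted_a(const tensor_view & a) {
    return a.permuted && !a.contiguous && a.ne[2] > 1 && a.ne[3] > 1;
}

// known-problematic case: plain (non-permuted) F16 A with ne[0] > 128 and ne[2] == 1,
// multiplied by a permuted, batched B
static bool unsupported_permuted_b(const tensor_view & a, const tensor_view & b) {
    return !a.permuted && b.permuted && b.ne[2] > 1 && b.ne[3] > 1 &&
           a.ne[0] > 128 && a.ne[2] == 1 && a.is_f16;
}

int main() {
    // example: plain F16 A (ne0 = 256, ne2 = 1) against a permuted, batched B
    tensor_view a = { { 256, 32, 1, 4 }, false, true,  true };
    tensor_view b = { { 256, 64, 8, 4 }, true,  false, true };
    return (!unsupported_permuted_a(a) && unsupported_permuted_b(a, b)) ? 0 : 1;
}
```

Both checks are conservative opt-outs: returning false from supports_op makes the backend decline these shapes until the oneDNN path handles them.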