@@ -3343,10 +3343,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33433343 SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
33443344 queue_ptr main_stream = ctx.stream ();;
33453345
3346- bool no_mixed_dtypes = main_stream->get_backend () == sycl::backend::ext_oneapi_cuda ||
3347- main_stream->get_backend () == sycl::backend::ext_oneapi_hip;
3348-
3349-
33503346 void * src0_ddq = src0->data ;
33513347 sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
33523348 float * src1_ddf = (float *) src1->data ;
@@ -3364,15 +3360,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33643360 sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
33653361 : src1_f16_alloc.get ();
33663362
3367- ggml_sycl_pool_alloc<sycl::half> dst_f16 (ctx.pool ());
33683363 char * dst_t ;
33693364
33703365 dpct::library_data_t cu_compute_type = dpct::library_data_t ::real_float;
33713366 dpct::library_data_t cu_data_type = dpct::library_data_t ::real_float;
3372- if (no_mixed_dtypes) {
3373- cu_compute_type = dpct::library_data_t ::real_half;
3374- cu_data_type = dpct::library_data_t ::real_half;
3375- }
33763367
33773368 // dst strides
33783369 size_t nbd2 = dst->nb [2 ];
@@ -3381,26 +3372,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
33813372 const float alpha_f32 = 1 .0f ;
33823373 const float beta_f32 = 0 .0f ;
33833374
3384- const sycl::half alpha_f16 = 1 .0f ;
3385- const sycl::half beta_f16 = 0 .0f ;
3386-
33873375 const void * alpha = &alpha_f32;
33883376 const void * beta = &beta_f32;
3389- if (no_mixed_dtypes) {
3390- alpha = &alpha_f16;
3391- beta = &beta_f16;
3392- }
3393-
3394- // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
3395- // when oneMKL open source supports half, half, float, float: datatypes
33963377
33973378 dst_t = (char *) dst_ddf;
3398- if (no_mixed_dtypes) {
3399- dst_t = (char *) dst_f16.alloc (ne_dst);
3400-
3401- nbd2 /= sizeof (float ) / sizeof (sycl::half);
3402- nbd3 /= sizeof (float ) / sizeof (sycl::half);
3403- }
34043379
34053380 GGML_ASSERT (ne12 % ne02 == 0 );
34063381 GGML_ASSERT (ne13 % ne03 == 0 );
@@ -3462,11 +3437,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
34623437 (void **)(ptrs_dst.get () + 0 * ne23), cu_data_type, ne01, ne23,
34633438 cu_compute_type)));
34643439 }
3465-
3466- if (no_mixed_dtypes) {
3467- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl (GGML_TYPE_F16);
3468- to_fp32_sycl (dst_f16.get (), dst_ddf, ne_dst, main_stream);
3469- }
34703440}
34713441catch (sycl::exception const &exc) {
34723442 std::cerr << exc.what () << " Exception caught at file:" << __FILE__
@@ -5069,9 +5039,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
50695039
50705040 ggml_type a_type = a->type ;
50715041
5072- if (op->op == GGML_OP_MUL_MAT_ID || op->op == GGML_OP_MUL_MAT_ID){
5073- if (op->src [0 ]->type == GGML_TYPE_BF16) return false ;
5074- }
50755042
50765043 if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS ||
50775044 a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S ||
@@ -5082,10 +5049,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
50825049 return false ;
50835050 }
50845051 }
5052+
50855053 ggml_type src0_type = op->src [0 ]->type ;
50865054 if (src0_type == GGML_TYPE_BF16) {
50875055 return false ;
50885056 }
5057+
50895058 return true ;
50905059 } break ;
50915060 case GGML_OP_GET_ROWS:
@@ -5133,7 +5102,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
51335102 case GGML_OP_CONCAT:
51345103 {
51355104 ggml_type src0_type = op->src [0 ]->type ;
5136- return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
5105+ int dim = op->op_params [0 ];
5106+ return ggml_is_contiguous (op->src [0 ]) && ggml_is_contiguous (op->src [1 ]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2 ;
51375107 } break ;
51385108 case GGML_OP_DUP:
51395109 case GGML_OP_NONE:
0 commit comments