@@ -1546,8 +1546,9 @@ static void mul_mat_p021_f16_f32(
 
 static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const int row_stride_x, const int channel_stride_x, const int channel_stride_y, const int channel_x_divisor,
     const sycl::nd_item<3> &item_ct1) {
+
 
     const sycl::half *x = (const sycl::half *)vx;
 
@@ -1557,7 +1558,6 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
                         item_ct1.get_local_id(0);
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
     const int nrows_dst = nrows_x;
     const int row_dst = row_x;
 
@@ -1576,7 +1576,7 @@ static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
         const int row_y = col_x;
 
         const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+        const int iy = channel * channel_stride_y + row_y;
 
         const float xi =
             sycl::vec<sycl::half, 1>(x[ix])
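The kernel change above replaces the hard-coded assumption that consecutive channels of y sit ncols_x floats apart with an explicit channel_stride_y derived from src1->nb[1]. A minimal standalone sketch of the resulting index arithmetic (not part of the patch; names mirror the kernel parameters, strides are in elements rather than bytes):

#include <cstdint>

// ix addresses the f16 src0, iy the f32 src1. channel_stride_y generalizes
// the old "channel * ncols_x" term, which only holds for tightly packed rows.
struct nc_index { int64_t ix, iy; };

static nc_index compute_nc_index(int64_t col_x, int64_t row_x, int64_t channel,
                                 int64_t channel_x_divisor,
                                 int64_t row_stride_x, int64_t channel_stride_x,
                                 int64_t channel_stride_y) {
    const int64_t channel_x = channel / channel_x_divisor; // src0 channels may be broadcast
    const int64_t row_y     = col_x;                       // element paired in the dot product
    return { channel_x * channel_stride_x + row_x * row_stride_x + col_x,
             channel   * channel_stride_y + row_y };
}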
@@ -1823,7 +1823,7 @@ static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
 static void ggml_mul_mat_vec_nc_f16_f32_sycl(
     const void *vx, const float *y, float *dst, const int ncols_x,
     const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
+    const int nchannels_y, const int channel_stride_x, const int channel_stride_y, queue_ptr stream) {
 
     const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -1835,7 +1835,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
         [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
             mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                   row_stride_x, channel_stride_x,
+                                   row_stride_x, channel_stride_x, channel_stride_y,
                                    nchannels_y / nchannels_x, item_ct1);
         });
 }
@@ -2124,8 +2124,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
-                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2171,8 +2171,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
-            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
-                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+            DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i,
+                DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
                 dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
@@ -2776,6 +2776,7 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
     const int64_t nb02 = src0->nb[2];
 
     const int64_t ne12 = src1->ne[2];
+    const int64_t nb11 = src1->nb[1];
 
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();
@@ -2786,8 +2787,9 @@ static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml
 
     const int64_t row_stride_x = nb01 / sizeof(sycl::half);
     const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    const int64_t channel_stride_y = nb11 / sizeof(float);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, channel_stride_y, main_stream);
 }
 catch (sycl::exception const &exc) {
     std::cerr << exc.what() << " Exception caught at file:" << __FILE__
@@ -2841,8 +2843,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     float * dst_ddf = static_cast<float *>(dst->data);
 
     const sycl::half * src1_f16 = static_cast<const sycl::half *>(src1->data);
+    const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
-    GGML_ASSERT(nb10 == type_size_src1);
 
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
@@ -2854,11 +2856,33 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     if (src1->type != GGML_TYPE_F16) {
         scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2,
                                              ": converting src1 to fp16");
-        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
-        GGML_ASSERT(to_fp16_nc_sycl != nullptr);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
+
+
+        // iterate tensor dims and find the slowest moving dim and stride
+        int64_t last_dim = 0;
+        int64_t last_str = 0;
+        int64_t largest_str = 0;
+        for (int i = 0; i < 4; i++) {
+            // last stride is always the largest
+            if (src1->nb[i] == largest_str) {
+                if (src1->ne[last_dim] == 1) {
+                    last_str = i;
+                    last_dim = i;
+                }
+            }
+            if (src1->nb[i] > largest_str) {
+                largest_str = src1->nb[i];
+                last_str = i;
+                last_dim = i;
+            }
+
+        }
+        const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
+        src1_f16_alloc.alloc(ne_src1);
+
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_f16, src1_f16_alloc.get(), ne_src1, queue);
 
         src1_f16 = src1_f16_alloc.get();
         s11 = ne10;
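The rewritten conversion path sizes the fp16 staging buffer from the slowest-moving dimension instead of ggml_nelements(src1), so any padding between rows or planes of a non-contiguous src1 survives the plain contiguous to_fp16 copy. A standalone sketch of that sizing rule (4-element ne/nb arrays standing in for the ggml tensor fields; illustrative only):

#include <cstddef>
#include <cstdint>

// Elements to allocate so that a flat copy of the whole, possibly padded,
// tensor fits: byte stride of the slowest-moving dim times its extent.
static int64_t padded_nelements(const int64_t ne[4], const size_t nb[4],
                                size_t type_size) {
    int last_dim = 0;
    size_t largest_nb = 0;
    for (int i = 0; i < 4; i++) {
        // the slowest-moving dim has the largest stride; on a tie, move on
        // only if the currently selected dim has extent 1 (as in the diff)
        if (nb[i] > largest_nb || (nb[i] == largest_nb && ne[last_dim] == 1)) {
            largest_nb = nb[i];
            last_dim   = i;
        }
    }
    return (int64_t)(nb[last_dim] * ne[last_dim] / type_size);
}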
@@ -2892,38 +2916,89 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
 #if GGML_SYCL_DNNL
     if (!g_ggml_sycl_disable_dnn) {
-        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
-            (const sycl::half* src1, const sycl::half* src0, float * dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
-
-            DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
-                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
-                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
-                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
-        };
-
-        if (r2 == 1 && r3 == 1) {
-            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
-            }
-            else {
-                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
-                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
-                    float * dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
-                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+        int64_t str_a0 = nb00 / type_size_src0;
+        int64_t str_a1 = nb01 / type_size_src0;
+        int64_t str_a2 = nb02 / type_size_src0;
+
+        int64_t str_b0 = nb10 / type_size_src1;
+        int64_t str_b1 = nb11 / type_size_src1;
+        int64_t str_b2 = nb12 / type_size_src1;
+
+        auto launch_gemm_for_batches = [&ctx, queue](const sycl::half *src0,
+                const sycl::half *src1, float *dst,
+                int64_t a0, int64_t a1, int64_t batcha,
+                int64_t b0, int64_t b1, int64_t batchb,
+                int64_t sa0, int64_t sa1, int64_t sa2,
+                int64_t sb0, int64_t sb1, int64_t sb2,
+                int64_t sd2) {
+            bool supported_broadcast = batchb == batcha ? true
+                : batchb == 1 || batcha == 1 ? true
+                : false;
+            if (supported_broadcast) {
+                DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2, src1,
+                    DnnlGemmWrapper::to_dt<sycl::half>(), sb0, sb1, sb2, dst,
+                    DnnlGemmWrapper::to_dt<float>(), queue, batcha, batchb);
+            } else {
+                // iterate over batches from smaller set of matrices (matrix 0)
+                int64_t batches0 = batcha;
+                int64_t batches1 = batchb;
+
+                if (batches0 > batches1) {
+                    int64_t num_mul_mats = batches1;
+                    int64_t sub_batch = batches0 / num_mul_mats;
+                    // src0 is batched and bigger, shift and multiply with src1
+                    for (int64_t i0 = 0; i0 < num_mul_mats; i0++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i0 * sub_batch);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i0);
+                        float *dst_shifted = dst + (sd2 * i0 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, sub_batch, 1);
+                    }
+                } else {
+                    int64_t num_mul_mats = batches0;
+                    int64_t sub_batch = batches1 / num_mul_mats;
+                    // src1 is batched and bigger, shift and multiply with src0
+                    for (int64_t i1 = 0; i1 < num_mul_mats; i1++) {
+                        const sycl::half *src0_shifted = src0 + (sa2 * i1);
+                        const sycl::half *src1_shifted = src1 + (sb2 * i1 * sub_batch);
+                        float *dst_shifted = dst + (sd2 * i1 * sub_batch);
+                        DnnlGemmWrapper::gemm(ctx, a1, b1, a0, src0_shifted,
+                            DnnlGemmWrapper::to_dt<sycl::half>(), sa0, sa1, sa2,
+                            src1_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), sb0,
+                            sb1, sb2, dst_shifted, DnnlGemmWrapper::to_dt<float>(),
+                            queue, 1, sub_batch);
+                    }
+                }
             }
-            }
-        } else {
-            // iterate over batches from smaller set of matrices (matrix 0)
-            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
-                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
-                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
-                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
-                    float * dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
-                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+        };
+
+        bool cont_batches_a = nb02 * ne02 == nb03;
+        bool cont_batches_b = nb12 * ne12 == nb13;
+        if (cont_batches_a && cont_batches_b) {
+            int64_t batches0 = ne02 * ne03;
+            int64_t batches1 = ne12 * ne13;
+            launch_gemm_for_batches(src0_f16, src1_f16, dst_ddf, ne00, ne01, batches0,
+                ne10, ne11, batches1, str_a0, str_a1, str_a2, str_b0, str_b1,
+                str_b2, nb2 / sizeof(float));
+        } else {
+            for (int64_t b_a = 0; b_a < ne03; b_a++) {
+                const sycl::half *src0_f16_shifted
+                    = src0_f16 + (nb03 * b_a / type_size_src0);
+                const sycl::half *src1_f16_shifted
+                    = src1_f16 + (nb13 * b_a / type_size_src1);
+                float *dst_shifted = dst_ddf + (nb3 * b_a / sizeof(float));
+                int64_t batches0 = ne02;
+                int64_t batches1 = ne12;
+                launch_gemm_for_batches(src0_f16_shifted, src1_f16_shifted, dst_shifted,
+                    ne00, ne01, batches0, ne10, ne11, batches1, str_a0, str_a1,
+                    str_a2, str_b0, str_b1, str_b2, nb2 / sizeof(float));
             }
         }
-        }
+
     }
     else
 #endif
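launch_gemm_for_batches above folds the previous r2/r3 special cases into one rule: equal batch counts and 1-vs-N broadcasts go to oneDNN in a single strided-batched call, and any other ratio is split into sub-batched calls along the larger side. A condensed sketch of just that dispatch decision, with a hypothetical gemm_batched callback standing in for the DnnlGemmWrapper::gemm invocations:

#include <cstdint>
#include <functional>

// batch_a/batch_b: number of matrices on each operand. gemm_batched(off_a,
// off_b, n) is a stand-in (hypothetical signature) that multiplies n
// consecutive matrices starting at the given batch offsets.
static void dispatch_batched_gemm(
        int64_t batch_a, int64_t batch_b,
        const std::function<void(int64_t, int64_t, int64_t)> &gemm_batched) {
    if (batch_a == batch_b || batch_a == 1 || batch_b == 1) {
        // handled directly by oneDNN's batched/broadcast semantics
        gemm_batched(0, 0, batch_a > batch_b ? batch_a : batch_b);
        return;
    }
    // otherwise iterate over the smaller batch and pair each of its matrices
    // with a contiguous sub-batch of the larger operand
    const int64_t larger    = batch_a > batch_b ? batch_a : batch_b;
    const int64_t smaller   = batch_a > batch_b ? batch_b : batch_a;
    const int64_t sub_batch = larger / smaller;
    for (int64_t i = 0; i < smaller; i++) {
        if (batch_a > batch_b) {
            gemm_batched(i * sub_batch, i, sub_batch); // shift through batches of operand A
        } else {
            gemm_batched(i, i * sub_batch, sub_batch); // shift through batches of operand B
        }
    }
}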
@@ -3263,10 +3338,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         // The kernel from the if path is faster for that specific case, but does not support all mul mats.
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     }
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2] * src1->ne[3] > 1) {
         // KQ + KQV multi-batch
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {