@@ -2770,22 +2770,42 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
27702770    //  broadcast factors
27712771    const  auto  r2 = ne12/ne02;
27722772    const  auto  r3 = ne13/ne03;
2773-     const  auto  ne23 = ne12*ne13;
27742773
2775-     if  (r2 == 1  && r3 == 1  && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
2776-         //  there is no broadcast and src0, src1 are contiguous across dims 2, 3
27772774#if  GGML_SYCL_DNNL
2778-         DnnlGemmWrapper::gemm (ctx, ne11, ne01, ne10,
2775+     if  (r2 == 1  && r3 == 1 ) {
2776+         DnnlGemmWrapper::gemm (ctx, ne11,ne01, ne10,
27792777            src1_f16, DnnlGemmWrapper::to_dt<sycl::half>(), nb11/nb10, 1 , nb12/nb10,
27802778            src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(), 1 , nb01/nb00, nb02/nb00,
2781-             dst_t , DnnlGemmWrapper::to_dt<float >(), main_stream, ne23);
2779+             dst_t , DnnlGemmWrapper::to_dt<float >(), main_stream, ne12*ne13, ne02 * ne03);
2780+     } else  {
2781+         //  nb1X_scaled is in bytes as if matrix 1 type would be sycl::half (it may be already such or it may be 4-bytes)
2782+         const  auto  nb12_scaled = src1->type  == GGML_TYPE_F16 ? nb12 : nb12 / 2 ;
2783+         const  auto  nb13_scaled = src1->type  == GGML_TYPE_F16 ? nb13 : nb13 / 2 ;
2784+ 
2785+         //  iterate over batches from smaller set of matrices (matrix 0)
2786+         for  (int64_t  ie02 = 0 ; ie02 < ne02; ++ie02) {
2787+             for  (int64_t  ie03 = 0 ; ie03 < ne03; ++ie03) {
2788+                 const  sycl::half* src0_f16_shifted = src0_as_f16 + ((ie02*nb02 + ie03*nb03)/2 ); //  div2 cuz nb is in bytes and pointer is in f16 (2 bytes)
2789+                 const  sycl::half* src1_f16_shifted = src1_f16 + ((ie02*nb12_scaled*r2 + ie03*nb13_scaled*r3)/2 );
2790+                 float * dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/4 ); //  div4 cuz nb is in bytes and pointer is float (4 bytes)
2791+ 
2792+                 DnnlGemmWrapper::gemm (ctx, ne11,ne01, ne10,
2793+                 src1_f16_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), nb11/nb10, 1 , nb12/nb10,
2794+                 src0_f16_shifted, DnnlGemmWrapper::to_dt<sycl::half>(), 1 , nb01/nb00, nb02/nb00,
2795+                 dst_shifted, DnnlGemmWrapper::to_dt<float >(), main_stream, r2 * r3, 1 );
2796+             }
2797+         }
2798+     }
27822799#else 
2800+     const  auto  ne23 = ne12*ne13;
2801+     if  (r2 == 1  && r3 == 1  && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
2802+         //  there is no broadcast and src0, src1 are contiguous across dims 2, 3
2803+ 
27832804        SYCL_CHECK (CHECK_TRY_ERROR (dpct::gemm_batch (
27842805            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
27852806            (const  char  *) src0_as_f16, dpct::library_data_t ::real_half, nb01 / nb00, nb02 / nb00,
27862807            (const  char  *) src1_f16, dpct::library_data_t ::real_half, nb11 / nb10, nb12 / nb10, beta, (char  *) dst_t ,
27872808            cu_data_type, ne01, nb2 / nb0, ne23, cu_compute_type)));
2788- #endif 
27892809    } else  {
27902810
27912811        ggml_sycl_pool_alloc<const  void  *> ptrs_src (ctx.pool (), 2 *ne23);
@@ -2824,6 +2844,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
28242844            (const  void  **) (ptrs_src.get () + 1  * ne23), dpct::library_data_t ::real_half, nb11 / nb10, beta,
28252845            (void  **) (ptrs_dst.get () + 0  * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get ())));
28262846    }
2847+ #endif 
28272848}
28282849catch  (sycl::exception const  &exc) {
28292850  std::cerr << exc.what () << " Exception caught at file:"   << __FILE__
0 commit comments