@@ -1982,7 +1982,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
19821982
19831983 const int64_t ne00 = src0->ne [0 ];
19841984 const int64_t ne10 = src1->ne [0 ];
1985-
1985+ GGML_ASSERT (ne00 == ne10);
19861986
19871987 const int64_t row_diff = row_high - row_low;
19881988
@@ -2727,10 +2727,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
27272727 GGML_ASSERT (!ggml_is_transposed (src1));
27282728 GGML_ASSERT (!ggml_backend_buffer_is_sycl_split (src0->buffer ));
27292729 GGML_ASSERT (src0->type == GGML_TYPE_F16);
2730+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
27302731
27312732 GGML_TENSOR_BINARY_OP_LOCALS
27322733
2733-
27342734 SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
27352735 queue_ptr main_stream = ctx.stream ();;
27362736
@@ -2751,39 +2751,45 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
27512751 sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
27522752 : src1_f16_alloc.get ();
27532753
2754- char * dst_t ;
2755-
2756- dpct::library_data_t cu_compute_type = dpct::library_data_t ::real_float;
2757- dpct::library_data_t cu_data_type = dpct::library_data_t ::real_float;
2758-
2759- // dst strides
2760- size_t nbd2 = dst->nb [2 ];
2761- size_t nbd3 = dst->nb [3 ];
2754+ const dpct::library_data_t cu_compute_type = dpct::library_data_t ::real_float;
2755+ const dpct::library_data_t cu_data_type = dpct::library_data_t ::real_float;
27622756
27632757 const float alpha_f32 = 1 .0f ;
27642758 const float beta_f32 = 0 .0f ;
27652759
27662760 const void * alpha = &alpha_f32;
27672761 const void * beta = &beta_f32;
27682762
2769- dst_t = (char *) dst_ddf;
2763+ char * dst_t = (char *) dst_ddf;
27702764
27712765 GGML_ASSERT (ne12 % ne02 == 0 );
27722766 GGML_ASSERT (ne13 % ne03 == 0 );
2767+ GGML_ASSERT (ne01 == static_cast <int64_t >(nb1/nb0));
2768+ GGML_ASSERT (ne10 == ne00);
27732769
27742770 // broadcast factors
2775- const int64_t r2 = ne12/ne02;
2776- const int64_t r3 = ne13/ne03;
2771+ const auto r2 = ne12/ne02;
2772+ const auto r3 = ne13/ne03;
2773+ const auto ne23 = ne12*ne13;
27772774
27782775 if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2 (src0) && ggml_is_contiguous_2 (src1)) {
27792776 // there is no broadcast and src0, src1 are contiguous across dims 2, 3
2777+ #ifdef GGML_SYCL_DNNL
2778+ // TODO: use strided dnnl::memory::desc ctor in row_gemm to relax below assertions
2779+ GGML_ASSERT (nb11/nb10 == ne10);
2780+ GGML_ASSERT (nb01/nb00 == ne00);
2781+
2782+ DnnlGemmWrapper::row_gemm (ctx, false , true , ne11, ne01, ne10, src1_f16,
2783+ DnnlGemmWrapper::to_dt<sycl::half>(), src0_as_f16, DnnlGemmWrapper::to_dt<sycl::half>(),
2784+ dst_t , DnnlGemmWrapper::to_dt<float >(), main_stream, ne23);
2785+ #else
27802786 SYCL_CHECK (CHECK_TRY_ERROR (dpct::gemm_batch (
27812787 *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
27822788 (const char *) src0_as_f16, dpct::library_data_t ::real_half, nb01 / nb00, nb02 / nb00,
27832789 (const char *) src1_f16, dpct::library_data_t ::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t ,
2784- cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
2790+ cu_data_type, ne01, nb2 / nb0, ne23, cu_compute_type)));
2791+ #endif
27852792 } else {
2786- const int ne23 = ne12*ne13;
27872793
27882794 ggml_sycl_pool_alloc<const void *> ptrs_src (ctx.pool (), 2 *ne23);
27892795 ggml_sycl_pool_alloc< void *> ptrs_dst (ctx.pool (), 1 *ne23);
@@ -2811,7 +2817,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
28112817 dst_t , ptrs_src_get,
28122818 ptrs_dst_get, ne12, ne13, ne23,
28132819 nb02, nb03, nb12_scaled, nb13_scaled,
2814- nbd2, nbd3 , r2, r3, item_ct1);
2820+ nb2, nb3 , r2, r3, item_ct1);
28152821 });
28162822 });
28172823 }
@@ -3651,7 +3657,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
36513657 return GGML_STATUS_SUCCESS;
36523658 }
36533659
3654- sycl_ex::command_graph model_sycl_graph (*(sycl_ctx->stream ()));
3660+ sycl_ex::command_graph model_sycl_graph (*(sycl_ctx->stream ()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
3661+
36553662 model_sycl_graph.begin_recording (*(sycl_ctx->stream ()));
36563663 ggml_backend_sycl_graph_compute_impl (sycl_ctx, cgraph);
36573664 model_sycl_graph.end_recording ();
0 commit comments