@@ -224,7 +224,7 @@ static void soft_max_f32_sycl(const float * x, const T * mask,
224224 }
225225}
226226
227- void ggml_sycl_op_soft_max (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
227+ static void ggml_sycl_op_soft_max (ggml_backend_sycl_context & ctx, ggml_tensor * dst) try {
228228
229229 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
230230 GGML_ASSERT ( dst->type == GGML_TYPE_F32);
@@ -249,13 +249,26 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
249249
250250 if (dst->src [1 ] && dst->src [1 ]->type == GGML_TYPE_F16) {
251251 const sycl::half * src1_dd = static_cast <sycl::half *>(dst->src [1 ]->data );
252+ GGML_SYCL_DEBUG (" %s: Mask precision: F16\n " , __func__);
252253 soft_max_f32_sycl<sycl::half>(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias,
253254 main_stream, ctx.device );
254255 } else if (dst->src [1 ] && dst->src [1 ]->type == GGML_TYPE_F32) {
255256 const float * src1_dd = static_cast <const float *>(dst->src [1 ]->data );
257+ GGML_SYCL_DEBUG (" %s: Mask precision: F32\n " , __func__);
256258 soft_max_f32_sycl<float >(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device );
257259 } else {
258260 /* mask unavailable */
259- soft_max_f32_sycl<float >(src0_dd, nullptr , dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device );
261+ GGML_SYCL_DEBUG (" %s: No mask supplied\n " , __func__);
262+ soft_max_f32_sycl<float >(src0_dd, nullptr , dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream,
263+ ctx.device );
260264 }
265+ } catch (const sycl::exception & exc) {
266+ std::cerr << exc.what () << " Exception caught at file:" << __FILE__ << " , line:" << __LINE__ << std::endl;
267+ std::exit (1 );
268+ }
269+
270+ void ggml_sycl_softmax (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
271+ GGML_SYCL_DEBUG (" call %s\n " , __func__);
272+ ggml_sycl_op_soft_max (ctx, dst);
273+ GGML_SYCL_DEBUG (" call %s done\n " , __func__);
261274}
0 commit comments