@@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
21512151 sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
21522152}
21532153
2154+ inline void ggml_sycl_op_mean (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2155+ GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
2156+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
2157+
2158+ dpct::queue_ptr main_stream = ctx.stream ();
2159+ SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
2160+
2161+ const float * src0_dd = static_cast <const float *>(dst->src [0 ]->data );
2162+ float * dst_dd = static_cast <float *>(dst->data );
2163+
2164+ const int64_t ncols = dst->src [0 ]->ne [0 ];
2165+ const int64_t nrows = ggml_nrows (dst->src [0 ]);
2166+
2167+ sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
2168+
2169+ main_stream->parallel_for (
2170+ sycl::range<1 >(nrows),
2171+ [=](sycl::id<1 > row) {
2172+ dst_dd[row] /= ncols;
2173+ }
2174+ );
2175+ }
2176+
2177+
21542178inline void ggml_sycl_op_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
21552179 GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
21562180 GGML_ASSERT (dst->type == GGML_TYPE_I32);
@@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
35353559 ggml_sycl_op_sum_rows (ctx, dst);
35363560}
35373561
3562+ static void ggml_sycl_mean (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3563+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
3564+ GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
3565+ ggml_sycl_op_mean (ctx, dst);
3566+ }
3567+
35383568static void ggml_sycl_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
35393569 scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
35403570 GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
@@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37843814 case GGML_OP_SUM_ROWS:
37853815 ggml_sycl_sum_rows (ctx, dst);
37863816 break ;
3817+ case GGML_OP_MEAN:
3818+ ggml_sycl_mean (ctx, dst);
3819+ break ;
37873820 case GGML_OP_ARGSORT:
37883821 ggml_sycl_argsort (ctx, dst);
37893822 break ;
@@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44314464 return op->src [0 ]->type == GGML_TYPE_F32 && op->op_params [0 ] == GGML_SCALE_MODE_NEAREST;
44324465 case GGML_OP_SUM:
44334466 case GGML_OP_SUM_ROWS:
4467+ case GGML_OP_MEAN:
44344468 case GGML_OP_ARGSORT:
44354469 return ggml_is_contiguous (op->src [0 ]);
44364470 case GGML_OP_POOL_2D:
0 commit comments