Skip to content

Commit ee50ee1

Browse files
yael-works and CISC authored
SYCL: Add GGML_OP_MEAN operator support (ggml-org#16009)
* SYCL: Add GGML_OP_MEAN operator support
* SYCL: Fix formatting for GGML_OP_MEAN case
* Update ggml/src/ggml-sycl/ggml-sycl.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 7adc79c commit ee50ee1

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,6 +2151,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
21512151
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
21522152
}
21532153

2154+
// SYCL implementation of GGML_OP_MEAN for F32 tensors.
//
// Reads dst->src[0] and writes dst. The work is done in two device passes:
//   1. sum_rows_f32_sycl is launched over the source rows with dst as its output
//      (presumably leaving one sum per row in dst - verify against that helper);
//   2. a second kernel divides each of the first `row_count` floats in dst by the
//      row length (src ne[0]).
//
// ctx: backend context supplying the queue and the device id.
// dst: destination tensor; both dst and dst->src[0] must be GGML_TYPE_F32.
//
// NOTE(review): pass 2 consumes what pass 1 wrote and no explicit dependency is
// expressed here, so this assumes the queue executes submissions in order - confirm.
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr queue = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const ggml_tensor * src0 = dst->src[0];

    const int64_t row_len   = src0->ne[0];
    const int64_t row_count = ggml_nrows(src0);

    const float * src_data = static_cast<const float *>(src0->data);
    float *       out_data = static_cast<float *>(dst->data);

    // Pass 1: row reduction into out_data.
    sum_rows_f32_sycl(src_data, out_data, row_len, row_count, queue);

    // Pass 2: one work-item per row; divide the stored value by the row length.
    const sycl::range<1> rows(row_count);
    queue->parallel_for(rows, [=](sycl::id<1> r) {
        out_data[r] /= row_len;
    });
}
2176+
2177+
21542178
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
21552179
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
21562180
GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -3535,6 +3559,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
35353559
ggml_sycl_op_sum_rows(ctx, dst);
35363560
}
35373561

3562+
// Dispatch-level wrapper for the mean operator.
// Creates the scoped debug printer for this op (one source tensor), asserts that
// the source tensor is contiguous, and then hands off to ggml_sycl_op_mean.
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Scoped object: constructed first so it spans the whole call.
    scope_op_debug_print op_trace(__func__, dst, /*num_src=*/1);

    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));

    ggml_sycl_op_mean(ctx, dst);
}
3567+
35383568
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
35393569
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
35403570
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3784,6 +3814,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37843814
case GGML_OP_SUM_ROWS:
37853815
ggml_sycl_sum_rows(ctx, dst);
37863816
break;
3817+
case GGML_OP_MEAN:
3818+
ggml_sycl_mean(ctx, dst);
3819+
break;
37873820
case GGML_OP_ARGSORT:
37883821
ggml_sycl_argsort(ctx, dst);
37893822
break;
@@ -4431,6 +4464,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44314464
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
44324465
case GGML_OP_SUM:
44334466
case GGML_OP_SUM_ROWS:
4467+
case GGML_OP_MEAN:
44344468
case GGML_OP_ARGSORT:
44354469
return ggml_is_contiguous(op->src[0]);
44364470
case GGML_OP_POOL_2D:

0 commit comments

Comments
 (0)