Skip to content

Commit b907255

Browse files
authored
SYCL: Add COUNT_EQUAL operator support (ggml-org#15991)
* SYCL: Add COUNT_EQUAL operator support (rebased on master) * SYCL: remove duplicate op_count_equal definition * tests: remove test_count_equal_typed and use test_count_equal for all cases * tests: keep only I32 case for COUNT_EQUAL as suggested * tests: keep only I32 case for COUNT_EQUAL as requested
1 parent 28c39da commit b907255

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
303303
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
304304
}
305305

306+
inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
307+
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
308+
}
309+
306310
inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
307311

308312
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
@@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
328332
ggml_sycl_op_sub(ctx, dst);
329333
}
330334

335+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
336+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
337+
ggml_sycl_op_count_equal(ctx, dst);
338+
}
339+
331340
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
332341
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
333342
ggml_sycl_op_mul(ctx, dst);

ggml/src/ggml-sycl/binbcast.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
1616
return a - b;
1717
}
1818

19+
static __dpct_inline__ float op_count_equal(const float a, const float b) {
20+
return (a == b) ? 1.0f : 0.0f;
21+
}
22+
23+
void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
24+
1925
static __dpct_inline__ float op_mul(const float a, const float b) {
2026
return a * b;
2127
}

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35773577
case GGML_OP_SUB:
35783578
ggml_sycl_sub(ctx, dst);
35793579
break;
3580+
case GGML_OP_COUNT_EQUAL:
3581+
ggml_sycl_count_equal(ctx, dst);
3582+
break;
35803583
case GGML_OP_ACC:
35813584
ggml_sycl_acc(ctx, dst);
35823585
break;
@@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
43564359
case GGML_OP_ADD:
43574360
case GGML_OP_ADD1:
43584361
case GGML_OP_SUB:
4362+
case GGML_OP_COUNT_EQUAL:
43594363
case GGML_OP_MUL:
43604364
case GGML_OP_DIV:
43614365
case GGML_OP_REPEAT:

0 commit comments

Comments
 (0)