Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ggml/src/ggml-sycl/binbcast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,10 @@ inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
}

inline void ggml_sycl_op_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_count_equal>>(ctx, dst->src[0], dst->src[1], dst);
}

inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {

ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
Expand All @@ -328,6 +332,11 @@ void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_sub(ctx, dst);
}

void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_count_equal(ctx, dst);
}

void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
ggml_sycl_op_mul(ctx, dst);
Expand Down
6 changes: 6 additions & 0 deletions ggml/src/ggml-sycl/binbcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ static __dpct_inline__ float op_sub(const float a, const float b) {
return a - b;
}

static __dpct_inline__ float op_count_equal(const float a, const float b) {
return (a == b) ? 1.0f : 0.0f;
}

void ggml_sycl_count_equal(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

static __dpct_inline__ float op_mul(const float a, const float b) {
return a * b;
}
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3577,6 +3577,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_OP_SUB:
ggml_sycl_sub(ctx, dst);
break;
case GGML_OP_COUNT_EQUAL:
ggml_sycl_count_equal(ctx, dst);
break;
case GGML_OP_ACC:
ggml_sycl_acc(ctx, dst);
break;
Expand Down Expand Up @@ -4356,6 +4359,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_COUNT_EQUAL:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_REPEAT:
Expand Down
30 changes: 30 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2236,6 +2236,30 @@ struct test_count_equal : public test_case {
}
};

/* COUNT_EQUAL – typed test (no argmax), to cover F32/F16/I32/I16 */
struct test_count_equal_typed : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

test_count_equal_typed(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {128, 64, 1, 1})
: type(type), ne(ne) {}

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
ggml_tensor * out = ggml_count_equal(ctx, a, b);
ggml_set_name(out, "out");
return out;
}
};

// GGML_OP_REPEAT
struct test_repeat : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -5940,6 +5964,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {

test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
// COUNT_EQUAL – typed tests by dtype
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F32, {1024, 1, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F32, { 64, 64, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_F16, { 256, 32, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_I32, { 512, 16, 1, 1}));
test_cases.emplace_back(new test_count_equal_typed(GGML_TYPE_I16, { 512, 16, 1, 1}));

test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 513, 1, 1}));
Expand Down
Loading