Skip to content

Commit f5a0dcb

Browse files
author
Gitty Burstein
committed
move SET op to standalone file, GPU-only implementation
1 parent e654008 commit f5a0dcb

File tree

6 files changed

+176
-163
lines changed

6 files changed

+176
-163
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 1 addition & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#include "ggml-sycl/presets.hpp"
33
#include "ggml.h"
44
#include "element_wise.hpp"
5-
#include <cstring>
5+
66
#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
77
for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
88

@@ -926,132 +926,6 @@ static inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor
926926
ggml_sycl_detail::pad_sycl(src, dst_ptr, ne00, ne01, ne02, ne0, ne1, ne2, stream);
927927
});
928928
}
929-
static inline void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
930-
const ggml_tensor * src0 = dst->src[0];
931-
GGML_ASSERT(dst->src[1] != nullptr);
932-
const ggml_tensor * src1 = dst->src[1];
933-
934-
GGML_ASSERT(src0->type == dst->type);
935-
GGML_ASSERT(src1->type == dst->type);
936-
#if defined(GGML_SYCL_F16)
937-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_I32);
938-
#else
939-
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32);
940-
#endif
941-
const size_t ts = ggml_type_size(dst->type);
942-
dpct::queue_ptr q = ctx.stream();
943-
{
944-
const bool same_type = (src0->type == dst->type);
945-
const bool src_cont = ggml_is_contiguous(src0);
946-
const bool dst_cont = ggml_is_contiguous(dst);
947-
948-
const void *p_src0 = src0->data;
949-
void *p_dst = dst->data;
950-
951-
auto pt_src0 = sycl::get_pointer_type((const char*)p_src0, q->get_context());
952-
auto pt_dst = sycl::get_pointer_type((char*)p_dst, q->get_context());
953-
954-
if (same_type && src_cont && dst_cont && ggml_nelements(src0) == ggml_nelements(dst)) {
955-
const size_t bytes = ggml_nbytes(dst);
956-
if (pt_src0 != sycl::usm::alloc::unknown && pt_dst != sycl::usm::alloc::unknown) {
957-
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(p_dst, p_src0, bytes)));
958-
} else {
959-
std::memcpy(p_dst, p_src0, bytes);
960-
}
961-
} else {
962-
const int64_t ne0 = dst->ne[0], ne1 = dst->ne[1], ne2 = dst->ne[2], ne3 = dst->ne[3];
963-
const size_t db0 = dst->nb[0], db1 = dst->nb[1], db2 = dst->nb[2], db3 = dst->nb[3];
964-
const size_t sb0 = src0->nb[0], sb1 = src0->nb[1], sb2 = src0->nb[2], sb3 = src0->nb[3];
965-
966-
const size_t N = (size_t) ggml_nelements(dst);
967-
const size_t WG = 256;
968-
const size_t NG = ((N + WG - 1) / WG) * WG;
969-
970-
const size_t ge0 = (size_t) ne0;
971-
const size_t ge1 = ge0 * (size_t) ne1;
972-
const size_t ge2 = ge1 * (size_t) ne2;
973-
974-
q->parallel_for(
975-
sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),
976-
[=](sycl::nd_item<1> it) {
977-
size_t idx = it.get_global_linear_id();
978-
if (idx >= N) return;
979-
980-
size_t i3 = idx / ge2; size_t r2 = idx % ge2;
981-
size_t i2 = r2 / ge1; size_t r1 = r2 % ge1;
982-
size_t i1 = r1 / ge0; size_t i0 = r1 % ge0;
983-
984-
const char * s = (const char*)p_src0 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3);
985-
char * d = (char*)p_dst + (i0*db0 + i1*db1 + i2*db2 + i3*db3);
986-
987-
for (size_t b = 0; b < ts; ++b) d[b] = s[b];
988-
}
989-
);
990-
}
991-
}
992-
993-
{
994-
const int32_t *p = (const int32_t *) dst->op_params;
995-
const size_t nb1 = (size_t) p[0];
996-
const size_t nb2 = (size_t) p[1];
997-
const size_t nb3 = (size_t) p[2];
998-
const size_t offset = (size_t) p[3];
999-
1000-
const void *p_src1 = src1->data;
1001-
void *p_dst = dst->data;
1002-
1003-
const size_t sb0 = src1->nb[0], sb1 = src1->nb[1], sb2 = src1->nb[2], sb3 = src1->nb[3];
1004-
const size_t db0 = dst->nb[0];
1005-
const int64_t ne0 = src1->ne[0], ne1 = src1->ne[1], ne2 = src1->ne[2], ne3 = src1->ne[3];
1006-
if (ggml_is_contiguous(src1) && db0 == ts) {
1007-
const size_t row_bytes = (size_t) ne0 * ts;
1008-
const char *s_base = (const char*) p_src1;
1009-
char *d_base = (char*) p_dst + offset;
1010-
1011-
for (int64_t i3 = 0; i3 < ne3; ++i3) {
1012-
for (int64_t i2 = 0; i2 < ne2; ++i2) {
1013-
for (int64_t i1 = 0; i1 < ne1; ++i1) {
1014-
const char *s_row = s_base + i1*sb1 + i2*sb2 + i3*sb3;
1015-
char *d_row = d_base + i1*nb1 + i2*nb2 + i3*nb3;
1016-
1017-
auto pt_s = sycl::get_pointer_type(s_row, q->get_context());
1018-
auto pt_d = sycl::get_pointer_type(d_row, q->get_context());
1019-
if (pt_s != sycl::usm::alloc::unknown && pt_d != sycl::usm::alloc::unknown) {
1020-
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(d_row, s_row, row_bytes)));
1021-
} else {
1022-
std::memcpy(d_row, s_row, row_bytes);
1023-
}
1024-
}
1025-
}
1026-
}
1027-
} else {
1028-
1029-
const size_t N = (size_t) (ne0 * ne1 * ne2 * ne3);
1030-
const size_t WG = 256;
1031-
const size_t NG = ((N + WG - 1) / WG) * WG;
1032-
const size_t ge0 = (size_t) ne0;
1033-
const size_t ge1 = ge0 * (size_t) ne1;
1034-
const size_t ge2 = ge1 * (size_t) ne2;
1035-
1036-
q->parallel_for(
1037-
sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),
1038-
[=](sycl::nd_item<1> it) {
1039-
size_t idx = it.get_global_linear_id();
1040-
if (idx >= N) return;
1041-
1042-
size_t i3 = idx / ge2; size_t r2 = idx % ge2;
1043-
size_t i2 = r2 / ge1; size_t r1 = r2 % ge1;
1044-
size_t i1 = r1 / ge0; size_t i0 = r1 % ge0;
1045-
1046-
const char * s = (const char*) p_src1 + (i0*sb0 + i1*sb1 + i2*sb2 + i3*sb3);
1047-
char * d = (char*) p_dst + offset + (i0*db0 + i1*nb1 + i2*nb2 + i3*nb3);
1048-
1049-
for (size_t b = 0; b < ts; ++b) d[b] = s[b];
1050-
}
1051-
);
1052-
}
1053-
}
1054-
}
1055929

1056930
static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1057931
float min_val;
@@ -1250,11 +1124,6 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
12501124
ggml_sycl_op_pad(ctx, dst);
12511125
}
12521126

1253-
void ggml_sycl_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1254-
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
1255-
ggml_sycl_op_set(ctx, dst);
1256-
}
1257-
12581127
void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
12591128
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
12601129
ggml_sycl_op_clamp(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,4 @@ void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8383
void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8484
void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
8585

86-
void ggml_sycl_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
87-
8886
#endif // GGML_SYCL_ELEMENTWISE_HPP

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "ggml-sycl/presets.hpp"
4343
#include "ggml-sycl/gemm.hpp"
4444
#include "ggml-sycl/set_rows.hpp"
45+
#include "ggml-sycl/set.hpp"
4546
#include "ggml-sycl/sycl_hw.hpp"
4647
#include "ggml-sycl/getrows.hpp"
4748
#include "ggml-sycl/quantize.hpp"
@@ -3565,7 +3566,7 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35653566
ggml_sycl_get_rows(ctx, dst);
35663567
break;
35673568
case GGML_OP_SET:
3568-
ggml_sycl_set(ctx, dst);
3569+
ggml_sycl_op_set(ctx, dst);
35693570
break;
35703571
case GGML_OP_SET_ROWS:
35713572
ggml_sycl_op_set_rows(ctx, dst);
@@ -4170,34 +4171,6 @@ static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_
41704171

41714172
static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
41724173
switch (op->op) {
4173-
case GGML_OP_SET: {
4174-
#if defined(GGML_SYCL_F16)
4175-
const bool types_ok =
4176-
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_I32) &&
4177-
(op->src[0]->type == op->type) &&
4178-
(op->src[1] && op->src[1]->type == op->type);
4179-
#else
4180-
const bool types_ok =
4181-
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) &&
4182-
(op->src[0]->type == op->type) &&
4183-
(op->src[1] && op->src[1]->type == op->type);
4184-
#endif
4185-
4186-
const bool contiguous_ok =
4187-
ggml_is_contiguous(op->src[0]) &&
4188-
(!op->src[1] || ggml_is_contiguous(op->src[1]));
4189-
4190-
return types_ok && contiguous_ok;
4191-
}
4192-
case GGML_OP_CONV_TRANSPOSE_1D:
4193-
{
4194-
ggml_type src0_type = op->src[0]->type;
4195-
ggml_type src1_type = op->src[1]->type;
4196-
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
4197-
return true;
4198-
}
4199-
return false;
4200-
}
42014174
case GGML_OP_UNARY:
42024175
switch (ggml_get_unary_op(op)) {
42034176
case GGML_UNARY_OP_NEG:
@@ -4288,6 +4261,18 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42884261
return false;
42894262
}
42904263
}
4264+
case GGML_OP_SET:
4265+
#if defined(GGML_SYCL_F16)
4266+
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_I32) &&
4267+
(op->src[0] && op->src[1]) &&
4268+
(op->src[0]->type == op->type) &&
4269+
(op->src[1]->type == op->type);
4270+
#else
4271+
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) &&
4272+
(op->src[0] && op->src[1]) &&
4273+
(op->src[0]->type == op->type) &&
4274+
(op->src[1]->type == op->type);
4275+
#endif
42914276
case GGML_OP_SET_ROWS:
42924277
{
42934278
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||

ggml/src/ggml-sycl/presets.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#define SYCL_SQRT_BLOCK_SIZE 256
3232
#define SYCL_SIN_BLOCK_SIZE 256
3333
#define SYCL_SQR_BLOCK_SIZE 256
34+
#define SYCL_SET_BLOCK_SIZE 256
3435
#define SYCL_CPY_BLOCK_SIZE 32
3536
#define SYCL_SCALE_BLOCK_SIZE 256
3637
#define SYCL_CLAMP_BLOCK_SIZE 256

ggml/src/ggml-sycl/set.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// ggml/src/ggml-sycl/set.cpp
2+
//
3+
// SYCL backend for GGML SET operator.
4+
//
5+
// Semantics:
6+
// 1) dst <- src0
7+
// 2) copy a sub-block from src1 into dst at byte `offset`,
8+
// using destination byte-strides (nb1, nb2, nb3) for dims 1..3.
9+
//
10+
// Notes:
11+
// - (nb1, nb2, nb3, offset) are BYTES (CPU-compatible).
12+
// - Uses two fast paths (bulk memcpy; row-wise memcpy) and a generic 4D kernel.
13+
// - Work-group size is configured in presets (SYCL_SET_BLOCK_SIZE).
14+
//
15+
// Implementation style aligned with other SYCL operators:
16+
// - No host std::memcpy fallback; no USM detection.
17+
// - Copies use queue->memcpy; generic case uses a parallel_for kernel.
18+
19+
#include "presets.hpp" // SYCL_* tuning (incl. SYCL_SET_BLOCK_SIZE)
20+
#include "common.hpp"
21+
#include "ggml.h"
22+
#include "set.hpp"
23+
24+
#include <cstdint>
25+
#include <cstring>
26+
27+
// ---------------- helpers (file-local) ----------------
28+
29+
// Byte-accurate 4D copy with independent src/dst byte strides.
30+
// One work-item copies exactly one element (ts bytes).
31+
static inline void launch_copy_4d_bytes(
32+
dpct::queue_ptr q,
33+
const void *p_src, void *p_dst,
34+
const int64_t ne[4],
35+
const size_t sb[4],
36+
const size_t db[4],
37+
const size_t ts
38+
) {
39+
const size_t N = (size_t)(ne[0] * ne[1] * ne[2] * ne[3]);
40+
if (N == 0) return;
41+
42+
const size_t WG = (size_t)SYCL_SET_BLOCK_SIZE;
43+
const size_t NG = ((N + WG - 1) / WG) * WG;
44+
45+
const size_t ge0 = (size_t) ne[0];
46+
const size_t ge1 = ge0 * (size_t) ne[1];
47+
const size_t ge2 = ge1 * (size_t) ne[2];
48+
49+
q->parallel_for(
50+
sycl::nd_range<1>(sycl::range<1>(NG), sycl::range<1>(WG)),
51+
[=](sycl::nd_item<1> it) {
52+
size_t idx = it.get_global_linear_id();
53+
if (idx >= N) return;
54+
55+
// 4D indexing
56+
size_t i3 = idx / ge2; size_t r2 = idx % ge2;
57+
size_t i2 = r2 / ge1; size_t r1 = r2 % ge1;
58+
size_t i1 = r1 / ge0; size_t i0 = r1 % ge0;
59+
60+
const char *s = (const char *)p_src + (i0*sb[0] + i1*sb[1] + i2*sb[2] + i3*sb[3]);
61+
char *d = (char *)p_dst + (i0*db[0] + i1*db[1] + i2*db[2] + i3*db[3]);
62+
63+
#pragma unroll
64+
for (size_t b = 0; b < ts; ++b) {
65+
d[b] = s[b];
66+
}
67+
}
68+
);
69+
}
70+
71+
// --------------------------- operator ---------------------------
72+
73+
// SYCL implementation of the GGML SET operator.
//
// Stage 1 copies src0 into dst, unless the operation is in-place.
// Stage 2 pastes src1 into dst starting at byte `offset`, using the byte
// strides nb1/nb2/nb3 for destination dimensions 1..3 and dst->nb[0] for
// dimension 0.
//
// Parameters:
//   ctx  backend context; all work is enqueued on ctx.stream()
//   dst  destination tensor; dst->src[0] is the base tensor, dst->src[1] is
//        the block to paste, and dst->op_params holds int32 values
//        [ nb1, nb2, nb3, offset, inplace ]
//
// Asserts that dst and both sources are non-null, that all three share one
// type, and that the type is F32 or I32 (plus F16 when GGML_SYCL_F16 is
// defined). Copies are only enqueued; this function does not wait on the queue.
void ggml_sycl_op_set(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst != nullptr);
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    GGML_ASSERT(src0 != nullptr);
    GGML_ASSERT(src1 != nullptr);

    // Type constraints (CPU-compatible)
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(src1->type == dst->type);
#if defined(GGML_SYCL_F16)
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_I32);
#else
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32);
#endif

    dpct::queue_ptr q  = ctx.stream();
    const size_t    ts = ggml_type_size(dst->type);

    // op_params: [ nb1, nb2, nb3, offset, inplace ]; strides and offset are BYTES.
    // NOTE(review): the `inplace` slot follows the layout written by
    // ggml_set_impl and read by the CPU backend - confirm against ggml.c.
    const int32_t * params  = (const int32_t *) dst->op_params;
    const size_t    nb1     = (size_t) params[0];
    const size_t    nb2     = (size_t) params[1];
    const size_t    nb3     = (size_t) params[2];
    const size_t    offset  = (size_t) params[3];
    const bool      inplace = params[4] != 0;

    // Stage 1: dst <- src0.
    // Skipped for in-place SET and whenever dst already aliases src0: the data
    // is already where it belongs, and a memcpy whose source and destination
    // fully overlap is not a valid copy.
    if (!inplace && src0->data != dst->data) {
        const void * p_src0 = src0->data;
        void *       p_dst  = dst->data;

        const bool bulk_copy = ggml_is_contiguous(src0) && ggml_is_contiguous(dst) &&
                               ggml_nelements(src0) == ggml_nelements(dst);
        if (bulk_copy) {
            // Both tensors are dense with the same element count: one memcpy.
            SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(p_dst, p_src0, ggml_nbytes(dst))));
        } else {
            // Generic 4D copy driven by dst's extents.
            const int64_t ne[4] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] };
            const size_t  sb[4] = { src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3] };
            const size_t  db[4] = { dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3] };
            launch_copy_4d_bytes(q, p_src0, p_dst, ne, sb, db, ts);
        }
    }

    // Stage 2: paste the src1 block into dst at `offset`.
    {
        const void * p_src1 = src1->data;
        void *       p_base = (char *) dst->data + offset;

        const bool src1_cont = ggml_is_contiguous(src1);
        const bool dst_tight = (dst->nb[0] == ts);  // elements of a dst row are adjacent

        if (src1_cont && dst_tight) {
            // Every src1 row maps onto one contiguous run of bytes in dst, so
            // each row can be moved with a single device memcpy.
            const char * s_base    = (const char *) p_src1;
            char *       d_base    = (char *) p_base;
            const size_t row_bytes = (size_t) src1->ne[0] * ts;

            const size_t sb1 = src1->nb[1];
            const size_t sb2 = src1->nb[2];
            const size_t sb3 = src1->nb[3];

            for (int64_t i3 = 0; i3 < src1->ne[3]; ++i3) {
                for (int64_t i2 = 0; i2 < src1->ne[2]; ++i2) {
                    for (int64_t i1 = 0; i1 < src1->ne[1]; ++i1) {
                        const char * s_row = s_base + i1*sb1 + i2*sb2 + i3*sb3;
                        char *       d_row = d_base + i1*nb1 + i2*nb2 + i3*nb3;
                        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(d_row, s_row, row_bytes)));
                    }
                }
            }
        } else {
            // Generic 4D copy from src1 into the offset dst base, using the
            // op_params strides for destination dimensions 1..3.
            const int64_t ne[4] = { src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3] };
            const size_t  sb[4] = { src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3] };
            const size_t  db[4] = { dst->nb[0], nb1, nb2, nb3 };
            launch_copy_4d_bytes(q, p_src1, p_base, ne, sb, db, ts);
        }
    }
}

0 commit comments

Comments
 (0)