
Commit 7ce4ecf

sycl: add CONCAT operator support

1 parent ee09828 commit 7ce4ecf
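GGML_OP_CONCAT joins two tensors along a single dimension; before this commit the SYCL backend only accepted F32 operands, and with this change it also handles I16 and I32. As a rough host-side illustration of the kind of node being enabled (my own sketch, not taken from the commit; ggml API details can differ between revisions), an integer concat graph looks roughly like this:

#include "ggml.h"

// Hypothetical sketch: build (but do not schedule) a graph containing an I32
// concat along dim 2, the kind of node the SYCL backend previously rejected.
int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, 11, 12, 13, 14);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_I32, 11, 12,  7, 14);

    // result shape: 11 x 12 x 20 x 14 (sizes add along the concat dimension)
    struct ggml_tensor * c = ggml_concat(ctx, a, b, /*dim=*/2);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    // executing gf on the SYCL backend (ggml_backend_sycl_init, graph compute, ...)
    // is omitted here for brevity

    ggml_free(ctx);
    return 0;
}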

File tree

3 files changed: +73 -75 lines changed

ggml/src/ggml-sycl/concat.cpp

Lines changed: 69 additions & 43 deletions
@@ -13,7 +13,11 @@
 #include "concat.hpp"
 #include "common.hpp"
 
-static void concat_f32_dim0(const float *x, const float *y, float *dst,
+static inline size_t elem_size(ggml_type t) {
+    return ggml_type_size(t) / ggml_blck_size(t);
+}
+template <typename T>
+static void concat_T_dim0(const T *x, const T *y, T *dst,
                             const int ne0, const int ne00,
                             const sycl::nd_item<3> &item_ct1) {
     int nidx = item_ct1.get_local_id(2) +
@@ -36,7 +40,8 @@ static void concat_f32_dim0(const float *x, const float *y, float *dst,
     }
 }
 
-static void concat_f32_dim1(const float *x, const float *y, float *dst,
+template <typename T>
+static void concat_T_dim1(const T *x, const T *y, T *dst,
                             const int ne0, const int ne01,
                             const sycl::nd_item<3> &item_ct1) {
     int nidx = item_ct1.get_local_id(2) +
@@ -59,7 +64,8 @@ static void concat_f32_dim1(const float *x, const float *y, float *dst,
     }
 }
 
-static void concat_f32_dim2(const float *x, const float *y, float *dst,
+template <typename T>
+static void concat_T_dim2(const T *x, const T *y, T *dst,
                             const int ne0, const int ne02,
                             const sycl::nd_item<3> &item_ct1) {
     int nidx = item_ct1.get_local_id(2) +
@@ -82,45 +88,38 @@ static void concat_f32_dim2(const float *x, const float *y, float *dst,
     }
 }
 
-static void concat_f32_sycl(const float *x, const float *y, float *dst,
+template <typename T>
+static void concat_T_sycl(const T *x, const T *y, T *dst,
                             int ne00, int ne01, int ne02, int ne0, int ne1,
                             int ne2, int dim, queue_ptr stream) {
     int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
     sycl::range<3> gridDim(ne2, ne1, num_blocks);
     switch (dim) {
     case 0:
-        stream->parallel_for(
-            sycl::nd_range<3>(gridDim *
-                                  sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
-            });
-        break;
+        sycl_parallel_for(stream,
+                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+                          [=](sycl::nd_item<3> item_ct1) { concat_T_dim0<T>(x, y, dst, ne0, ne00, item_ct1); });
+        break;
     case 1:
-        stream->parallel_for(
-            sycl::nd_range<3>(gridDim *
-                                  sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
-            });
-        break;
+        sycl_parallel_for(stream,
+                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+                          [=](sycl::nd_item<3> item_ct1) { concat_T_dim1<T>(x, y, dst, ne0, ne01, item_ct1); });
+        break;
     // dim >=2 will be dispatched to the default path
     default:
-        stream->parallel_for(
-            sycl::nd_range<3>(gridDim *
-                                  sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
-            });
-        break;
+        sycl_parallel_for(stream,
+                          sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
+                                            sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
+                          [=](sycl::nd_item<3> item_ct1) { concat_T_dim2<T>(x, y, dst, ne0, ne02, item_ct1); });
+        break;
     }
 }
 
 // non-contiguous kernel (slow)
-static void concat_f32_sycl_non_cont(
+template<typename T>
+static void concat_T_sycl_non_cont(
     queue_ptr stream, const char *src0, const char *src1, char *dst,
     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
     uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
@@ -129,32 +128,33 @@ static void concat_f32_sycl_non_cont(
     int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
     uint64_t nb3, int32_t dim) {
     sycl::range<3> gridDim(ne3, ne2, ne1);
-    stream->parallel_for(sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
+    sycl_parallel_for(stream, sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)), [=](sycl::nd_item<3> item_ct1) {
         int64_t i3 = item_ct1.get_group(0);
         int64_t i2 = item_ct1.get_group(1);
         int64_t i1 = item_ct1.get_group(2);
 
         int64_t o[4] = { 0, 0, 0, 0 };
         o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-        const float * x;
+        const T * x;
 
         for (int i0 = item_ct1.get_local_id(2); i0 < ne0; i0 += item_ct1.get_local_range(2)) {
             if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-                x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
+                x = (const T *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
             } else {
-                x = (const float *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
+                x = (const T *) (src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 + (i1 - o[1]) * nb11 +
                                     (i0 - o[0]) * nb10);
             }
 
-            float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
+            T *y = (T *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
 
             *y = *x;
         }
     });
 }
 
-void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+template <typename T>
+void concat_impl_sycl(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
@@ -163,29 +163,55 @@ void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     const int32_t dim = ((int32_t *) dst->op_params)[0];
 
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        const float * src0_d = (const float *) src0->data;
-        const float * src1_d = (const float *) src1->data;
-
-        float * dst_d = (float *) dst->data;
+        const T * src0_d = (const T *) src0->data;
+        const T * src1_d = (const T *) src1->data;
 
+        T * dst_d = (T *) dst->data;
+
+        size_t type_size = elem_size(dst->type);
+
         if (dim != 3) {
             for (int i3 = 0; i3 < dst->ne[3]; i3++) {
-                concat_f32_sycl(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
-                                dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0],
+                concat_T_sycl<T>(src0_d + i3 * (src0->nb[3] / type_size), src1_d + i3 * (src1->nb[3] / type_size),
+                                 dst_d + i3 * (dst->nb[3] / type_size), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0],
                                 dst->ne[1], dst->ne[2], dim, stream);
             }
         } else {
             const size_t size0 = ggml_nbytes(src0);
             const size_t size1 = ggml_nbytes(src1);
 
             SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
-            SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
+            SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / type_size, src1_d, size1).wait()));
         }
     } else {
-        concat_f32_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
+        concat_T_sycl_non_cont<T>(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
                                  src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1],
                                  src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
                                  src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
                                  dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
    }
 }
+
+void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+
+    static std::atomic<bool> printed{false};
+    if (!printed.exchange(true)) std::fprintf(stderr, "[LP] hit ggml_sycl_op_concat\n");
+
+    LP_PROFILE_INIT_ONCE();
+    LP_PROFILE_PAIR(dst->src[0], dst->src[1]);
+
+    switch (dst->type) {
+        case GGML_TYPE_F32:
+            concat_impl_sycl<float>(ctx, dst);
+            break;
+        case GGML_TYPE_I16:
+            concat_impl_sycl<int16_t>(ctx, dst);
+            break;
+        case GGML_TYPE_I32:
+            concat_impl_sycl<int32_t>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false && "ggml_sycl_op_concat: unsupported type");
+            break;
+    }
+}
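The key generalization above is that the hard-coded "/ 4" (the byte size of a float) becomes a division by the per-element size of dst->type, since ggml strides (nb) are expressed in bytes. A standalone sketch of what elem_size() evaluates to for the newly supported types (my own illustration, not part of the commit; for non-quantized types the block size is 1):

#include <cstddef>
#include <cstdio>

#include "ggml.h"

// Standalone sketch (not part of the commit): the per-element byte width that
// elem_size() computes for the types the SYCL CONCAT path now accepts.
int main() {
    const ggml_type types[] = { GGML_TYPE_F32, GGML_TYPE_I16, GGML_TYPE_I32 };
    for (ggml_type t : types) {
        const size_t per_elem = ggml_type_size(t) / (size_t) ggml_blck_size(t);
        std::printf("%-4s -> %zu bytes per element\n", ggml_type_name(t), per_elem);
    }
    // Expected: f32 -> 4, i16 -> 2, i32 -> 4. This per-element size is the divisor
    // used when byte strides (nb) are converted into element offsets for src0,
    // src1 and dst.
    return 0;
}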

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 0 additions & 4 deletions
@@ -4419,10 +4419,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
         case GGML_OP_CONCAT:
-            {
-                ggml_type src0_type = op->src[0]->type;
-                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
-            }
         case GGML_OP_DUP:
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
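With the special case removed, GGML_OP_CONCAT falls through to the group of ops below it, so the SYCL device no longer rejects integer CONCAT nodes in supports_op. A hedged sketch of how that can be observed through the backend device registry (function names follow my reading of ggml-backend.h and may differ between ggml revisions):

#include <cstdio>

#include "ggml.h"
#include "ggml-backend.h"

// Hedged sketch: ask each registered backend device whether it supports an
// integer CONCAT node.
int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // only op metadata is needed for supports_op
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_I16, 11, 12, 13, 14);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_I16, 11, 12,  7, 14);
    struct ggml_tensor * c = ggml_concat(ctx, a, b, /*dim=*/2);

    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        std::printf("%s supports I16 CONCAT: %d\n",
                    ggml_backend_dev_name(dev),
                    (int) ggml_backend_dev_supports_op(dev, c));
    }

    ggml_free(ctx);
    return 0;
}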

tests/test-backend-ops.cpp

Lines changed: 4 additions & 28 deletions
@@ -2557,13 +2557,6 @@ struct test_cpy : public test_case {
 
         return out;
     }
-
-    void initialize_tensors(ggml_context * ctx) override {
-        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            // test extended range of values to check if casting between f32 and i32 is consistent
-            init_tensor_uniform(t, -150.f, 150.f);
-        }
-    }
 };
 
 // GGML_OP_CONT
@@ -6217,10 +6210,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
         }
     }
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
-    test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
 
     test_cases.emplace_back(new test_cont());
     test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
@@ -6253,9 +6242,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2});
         add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2});
 
-        // test case for k_bin_bcast_unravel in CUDA backend
-        add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});
-
         // stable diffusion
         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 1, 1, 1});
         add_test_bin_bcast(type, {1280, 1, 1, 1}, {1, 16, 16, 1});
@@ -6658,7 +6644,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (int64_t ne1 : {16, 1024}) {
                 test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 1, 1}, scale, max_bias));
                 test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
-                test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0, ne1, 2, 3}, scale, max_bias));
             }
         }
     }
@@ -6727,6 +6712,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         for (int dim : { 0, 1, 2, 3, }) {
             test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
             test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
+            test_cases.emplace_back(new test_concat(GGML_TYPE_I16, {11, 12, 13, 14}, 7, dim, v));
         }
     }
 
@@ -6784,8 +6770,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v));
     }
 
-    for (int hsk : { 40, 64, 80, 96, 128, 192, 256, 576 }) {
-        for (int hsv : { 40, 64, 80, 96, 128, 192, 256, 512 }) {
+    for (int hsk : { 40, 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 40, 64, 80, 128, 192, 256, 512 }) {
             if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
@@ -7115,17 +7101,7 @@ static void list_all_ops() {
 static void show_test_coverage() {
     std::set<std::string> all_ops;
     for (int i = 1; i < GGML_OP_COUNT; i++) {
-        auto op = (enum ggml_op)i;
-        if (op == GGML_OP_VIEW ||
-            op == GGML_OP_RESHAPE ||
-            op == GGML_OP_PERMUTE ||
-            op == GGML_OP_TRANSPOSE ||
-            op == GGML_OP_CONT ||
-            op == GGML_OP_GLU ||
-            op == GGML_OP_UNARY) {
-            continue;
-        }
-        all_ops.insert(ggml_op_name(op));
+        all_ops.insert(ggml_op_name((enum ggml_op)i));
     }
     for (int i = 0; i < GGML_UNARY_OP_COUNT; i++) {
         all_ops.insert(ggml_unary_op_name((enum ggml_unary_op)i));
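The new I16 case mirrors the existing F32/I32 ones: src0 has shape {11, 12, 13, 14}, src1 matches it except for a size of 7 along the concat dimension, and the surrounding loop over v appears to cover contiguous and non-contiguous variants of the operands, which is what reaches the concat_T_sycl_non_cont<T> path. Roughly, each case builds a graph like the following sketch (paraphrased, not the actual test_concat implementation):

#include "ggml.h"

// Paraphrased sketch of the shapes each new I16 case covers (the real logic
// lives in struct test_concat): src1 matches src0 except along the concat
// dimension, where it is 7 elements wide.
static struct ggml_tensor * build_concat_case(struct ggml_context * ctx, int dim) {
    int64_t ne_a[4] = { 11, 12, 13, 14 };
    int64_t ne_b[4] = { 11, 12, 13, 14 };
    ne_b[dim] = 7;

    struct ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_I16, 4, ne_a);
    struct ggml_tensor * b = ggml_new_tensor(ctx, GGML_TYPE_I16, 4, ne_b);
    return ggml_concat(ctx, a, b, dim);
}

Running test-backend-ops with its op filter set to CONCAT on a SYCL device should then exercise the new kernels for all three supported types (exact invocation depends on the build).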
