
Commit 710405b

tamarPal authored and committed
sycl: remove wait() calls from ROLL operation
1 parent 386db09 commit 710405b
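
Note on the change: the commit drops the blocking wait() after both the roll kernel submission and the zero-shift memcpy, so the ROLL op is only enqueued and the host no longer stalls on it. Below is a minimal standalone sketch of the two submission styles; it is illustrative only (queue setup, buffer, and kernels are hypothetical, not from this commit) and assumes an in-order SYCL queue, so later work enqueued on the same queue still runs after the kernel.

// Hypothetical sketch, not ggml code: contrasts a blocking submit with the
// asynchronous style this commit switches to, on an in-order SYCL queue.
#include <sycl/sycl.hpp>

int main() {
    sycl::queue q{sycl::property::queue::in_order{}};

    constexpr int n = 8;
    float *data = sycl::malloc_device<float>(n, q);

    // Blocking style: the host waits for the kernel before doing anything else.
    q.submit([&](sycl::handler &h) {
        h.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) { data[i] = (float) i[0]; });
    }).wait();

    // Asynchronous style: just enqueue; in-order semantics keep later
    // operations on this queue behind the kernel.
    q.submit([&](sycl::handler &h) {
        h.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) { data[i] += 1.0f; });
    });

    // Synchronize only where the host actually needs the result.
    float host[n];
    q.memcpy(host, data, sizeof(host)).wait();

    sycl::free(data, q);
    return 0;
}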

File tree

1 file changed: +29 −23 lines changed


ggml/src/ggml-sycl/roll.cpp

Lines changed: 29 additions & 23 deletions
@@ -9,19 +9,25 @@ static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor
     if (src->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32)
         throw std::runtime_error("only F32 supported in SYCL roll");
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
 
     if (ne0 != src->ne[0] || ne1 != src->ne[1] || ne2 != src->ne[2] || ne3 != src->ne[3])
         throw std::runtime_error("src/dst shape mismatch");
 
-    // Normalize shifts to be within bounds
-    const int64_t sh0 = ne0 > 0 ? ((int64_t)shift0 % ne0 + ne0) % ne0 : 0;
-    const int64_t sh1 = ne1 > 0 ? ((int64_t)shift1 % ne1 + ne1) % ne1 : 0;
-    const int64_t sh2 = ne2 > 0 ? ((int64_t)shift2 % ne2 + ne2) % ne2 : 0;
-    const int64_t sh3 = ne3 > 0 ? ((int64_t)shift3 % ne3 + ne3) % ne3 : 0;
+
+    const int sh0 = ne0 > 0 ? ((int)shift0 % ne0 + ne0) % ne0 : 0;
+    const int sh1 = ne1 > 0 ? ((int)shift1 % ne1 + ne1) % ne1 : 0;
+    const int sh2 = ne2 > 0 ? ((int)shift2 % ne2 + ne2) % ne2 : 0;
+    const int sh3 = ne3 > 0 ? ((int)shift3 % ne3 + ne3) % ne3 : 0;
+
+
+    const int shNe0 = ne0 - sh0;
+    const int shNe1 = ne1 - sh1;
+    const int shNe2 = ne2 - sh2;
+    const int shNe3 = ne3 - sh3;
 
     const float *src_d = (const float*) src->data;
     float *dst_d = (float*) dst->data;
@@ -31,26 +37,26 @@ static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor
     q.submit([&](handler &h) {
         range<3> r((size_t)ne3, (size_t)ne2, (size_t)ne1);
         h.parallel_for(r, [=](id<3> idx) {
-            const int64_t i3 = idx[0];
-            const int64_t i2 = idx[1];
-            const int64_t i1 = idx[2];
+            const int i3 = (int)idx[0];
+            const int i2 = (int)idx[1];
+            const int i1 = (int)idx[2];
 
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                const int64_t idx_dst = i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2;
+            for (int i0 = 0; i0 < ne0; i0++) {
+                const int idx_dst = i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2;
 
-                // Apply shift to each dimension
-                const int64_t src_i0 = (i0 - sh0 + ne0) % ne0;
-                const int64_t src_i1 = (i1 - sh1 + ne1) % ne1;
-                const int64_t src_i2 = (i2 - sh2 + ne2) % ne2;
-                const int64_t src_i3 = (i3 - sh3 + ne3) % ne3;
 
-                const int64_t idx_src = src_i0 + src_i1 * ne0 +
+                const int src_i0 = (i0 + shNe0) % ne0;
+                const int src_i1 = (i1 + shNe1) % ne1;
+                const int src_i2 = (i2 + shNe2) % ne2;
+                const int src_i3 = (i3 + shNe3) % ne3;
+
+                const int idx_src = src_i0 + src_i1 * ne0 +
                                         src_i2 * ne0 * ne1 + src_i3 * ne0 * ne1 * ne2;
 
                 dst_d[idx_dst] = src_d[idx_src];
             }
         });
-    }).wait();
+    });
 }
 
 void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
@@ -64,11 +70,11 @@ void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     const int shift2 = params[2];
     const int shift3 = params[3];
 
-    // Check if all shifts are zero
+
     if (shift0 == 0 && shift1 == 0 && shift2 == 0 && shift3 == 0) {
         const size_t nb = ggml_nbytes(src);
         queue *q = ctx.stream();
-        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb).wait()));
+        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
         return;
     }
 
0 commit comments
