Skip to content

Commit b979308

Browse files
tamarPal authored and committed
fix: editorconfig — LF endings + final newline for roll.hpp
1 parent 710405b commit b979308

File tree

2 files changed

+79
-45
lines changed

2 files changed

+79
-45
lines changed

ggml/src/ggml-sycl/roll.cpp

Lines changed: 78 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -3,58 +3,66 @@
33

44
using namespace sycl;
55

6-
static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor *dst,
7-
int shift0, int shift1, int shift2, int shift3) {
8-
if (!src || !dst) throw std::runtime_error("null tensor");
9-
if (src->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32)
10-
throw std::runtime_error("only F32 supported in SYCL roll");
6+
static inline int wrap_add(int i, int shift, int n) {
117

12-
const int ne0 = dst->ne[0];
13-
const int ne1 = dst->ne[1];
14-
const int ne2 = dst->ne[2];
15-
const int ne3 = dst->ne[3];
8+
int s = i + shift;
9+
return (s >= n) ? (s - n) : s;
10+
}
11+
12+
static void kernel_roll_fused_i0_i1(
13+
queue &q,
14+
const float *src_d,
15+
float *dst_d,
16+
int ne0, int ne1, int ne2, int ne3,
17+
int sh0, int sh1, int sh2, int sh3)
18+
{
19+
if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
1620

17-
if (ne0 != src->ne[0] || ne1 != src->ne[1] || ne2 != src->ne[2] || ne3 != src->ne[3])
18-
throw std::runtime_error("src/dst shape mismatch");
1921

22+
const int stride1 = ne0;
23+
const int stride2 = ne0 * ne1;
24+
const int stride3 = ne0 * ne1 * ne2;
2025

21-
const int sh0 = ne0 > 0 ? ((int)shift0 % ne0 + ne0) % ne0 : 0;
22-
const int sh1 = ne1 > 0 ? ((int)shift1 % ne1 + ne1) % ne1 : 0;
23-
const int sh2 = ne2 > 0 ? ((int)shift2 % ne2 + ne2) % ne2 : 0;
24-
const int sh3 = ne3 > 0 ? ((int)shift3 % ne3 + ne3) % ne3 : 0;
2526

27+
const int shNe0 = (ne0 - sh0) % ne0;
28+
const int shNe1 = (ne1 - sh1) % ne1;
29+
const int shNe2 = (ne2 - sh2) % ne2;
30+
const int shNe3 = (ne3 - sh3) % ne3;
2631

27-
const int shNe0 = ne0 - sh0;
28-
const int shNe1 = ne1 - sh1;
29-
const int shNe2 = ne2 - sh2;
30-
const int shNe3 = ne3 - sh3;
3132

32-
const float *src_d = (const float*) src->data;
33-
float *dst_d = (float*) dst->data;
33+
const size_t g0 = (size_t) ne3;
34+
const size_t g1 = (size_t) ne2;
35+
const size_t g2 = (size_t) (ne1 * ne0);
3436

35-
if (!src_d || !dst_d) throw std::runtime_error("null data pointers");
37+
const range<3> global{ g0, g1, g2 };
3638

3739
q.submit([&](handler &h) {
38-
range<3> r((size_t)ne3, (size_t)ne2, (size_t)ne1);
39-
h.parallel_for(r, [=](id<3> idx) {
40-
const int i3 = (int)idx[0];
41-
const int i2 = (int)idx[1];
42-
const int i1 = (int)idx[2];
40+
h.parallel_for(global, [=](id<3> idx) {
41+
const int i3 = (int) idx[0];
42+
const int i2 = (int) idx[1];
43+
44+
const int fused = (int) idx[2];
45+
const int i1 = fused / ne0;
46+
const int i0 = fused - i1 * ne0; // fused % ne0
47+
4348

44-
for (int i0 = 0; i0 < ne0; i0++) {
45-
const int idx_dst = i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2;
49+
const int idx_dst = i0
50+
+ i1 * stride1
51+
+ i2 * stride2
52+
+ i3 * stride3;
4653

4754

48-
const int src_i0 = (i0 + shNe0) % ne0;
49-
const int src_i1 = (i1 + shNe1) % ne1;
50-
const int src_i2 = (i2 + shNe2) % ne2;
51-
const int src_i3 = (i3 + shNe3) % ne3;
55+
const int s0 = wrap_add(i0, shNe0, ne0);
56+
const int s1 = wrap_add(i1, shNe1, ne1);
57+
const int s2 = wrap_add(i2, shNe2, ne2);
58+
const int s3 = wrap_add(i3, shNe3, ne3);
5259

53-
const int idx_src = src_i0 + src_i1 * ne0 +
54-
src_i2 * ne0 * ne1 + src_i3 * ne0 * ne1 * ne2;
60+
const int idx_src = s0
61+
+ s1 * stride1
62+
+ s2 * stride2
63+
+ s3 * stride3;
5564

56-
dst_d[idx_dst] = src_d[idx_src];
57-
}
65+
dst_d[idx_dst] = src_d[idx_src];
5866
});
5967
});
6068
}
@@ -63,24 +71,50 @@ void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
6371
GGML_ASSERT(dst->type == GGML_TYPE_F32);
6472

6573
const ggml_tensor *src = dst->src[0];
74+
GGML_ASSERT(src && src->type == GGML_TYPE_F32);
6675

67-
const int32_t *params = (const int32_t *)dst->op_params;
68-
const int shift0 = params[0];
69-
const int shift1 = params[1];
70-
const int shift2 = params[2];
71-
const int shift3 = params[3];
76+
const int ne0 = (int) dst->ne[0];
77+
const int ne1 = (int) dst->ne[1];
78+
const int ne2 = (int) dst->ne[2];
79+
const int ne3 = (int) dst->ne[3];
7280

81+
const int32_t *params = (const int32_t *) dst->op_params;
82+
int shift0 = params[0];
83+
int shift1 = params[1];
84+
int shift2 = params[2];
85+
int shift3 = params[3];
7386

74-
if (shift0 == 0 && shift1 == 0 && shift2 == 0 && shift3 == 0) {
87+
88+
if ((shift0 | shift1 | shift2 | shift3) == 0) {
7589
const size_t nb = ggml_nbytes(src);
7690
queue *q = ctx.stream();
7791
SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
7892
return;
7993
}
8094

95+
auto norm = [](int sh, int n) -> int {
96+
if (n <= 0) return 0;
97+
sh %= n;
98+
if (sh < 0) sh += n;
99+
return sh;
100+
};
101+
shift0 = norm(shift0, ne0);
102+
shift1 = norm(shift1, ne1);
103+
shift2 = norm(shift2, ne2);
104+
shift3 = norm(shift3, ne3);
105+
81106
try {
82107
queue *q = ctx.stream();
83-
kernel_roll_multi_axis(*q, src, dst, shift0, shift1, shift2, shift3);
108+
109+
const float *src_d = (const float *) src->data;
110+
float *dst_d = (float *) dst->data;
111+
GGML_ASSERT(src_d && dst_d);
112+
113+
kernel_roll_fused_i0_i1(
114+
*q, src_d, dst_d,
115+
ne0, ne1, ne2, ne3,
116+
shift0, shift1, shift2, shift3
117+
);
84118
} catch (const std::exception &e) {
85119
std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
86120
throw;

ggml/src/ggml-sycl/roll.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,4 +17,4 @@
1717

1818
void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
1919

20-
#endif // GGML_SYCL_ROLL_HPP
20+
#endif // GGML_SYCL_ROLL_HPP

0 commit comments

Comments
 (0)