
using namespace sycl;

-static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor *dst,
-                                   int shift0, int shift1, int shift2, int shift3) {
-    if (!src || !dst) throw std::runtime_error("null tensor");
-    if (src->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32)
-        throw std::runtime_error("only F32 supported in SYCL roll");
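+// Replaces (i + shift) % n with one compare-and-subtract; valid because
+// callers pass shift already normalized to [0, n), so i + shift < 2 * n.
+// Example: n = 5, shift = 3, i = 4 -> s = 7 -> 7 - 5 = 2 == (4 + 3) % 5.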
+static inline int wrap_add(int i, int shift, int n) {

-    const int ne0 = dst->ne[0];
-    const int ne1 = dst->ne[1];
-    const int ne2 = dst->ne[2];
-    const int ne3 = dst->ne[3];
+    int s = i + shift;
+    return (s >= n) ? (s - n) : s;
+}
+
+static void kernel_roll_fused_i0_i1(
+    queue &q,
+    const float *src_d,
+    float *dst_d,
+    int ne0, int ne1, int ne2, int ne3,
+    int sh0, int sh1, int sh2, int sh3)
+{
+    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;

-    if (ne0 != src->ne[0] || ne1 != src->ne[1] || ne2 != src->ne[2] || ne3 != src->ne[3])
-        throw std::runtime_error("src/dst shape mismatch");

+    const int stride1 = ne0;
+    const int stride2 = ne0 * ne1;
+    const int stride3 = ne0 * ne1 * ne2;

-    const int sh0 = ne0 > 0 ? ((int) shift0 % ne0 + ne0) % ne0 : 0;
-    const int sh1 = ne1 > 0 ? ((int) shift1 % ne1 + ne1) % ne1 : 0;
-    const int sh2 = ne2 > 0 ? ((int) shift2 % ne2 + ne2) % ne2 : 0;
-    const int sh3 = ne3 > 0 ? ((int) shift3 % ne3 + ne3) % ne3 : 0;

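+    // Inverse mapping: rolling dst forward by sh reads src at
+    // (i + (n - sh)) % n. The extra % n folds sh == 0 back to 0 so the
+    // offset handed to wrap_add always lies in [0, n).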
+    const int shNe0 = (ne0 - sh0) % ne0;
+    const int shNe1 = (ne1 - sh1) % ne1;
+    const int shNe2 = (ne2 - sh2) % ne2;
+    const int shNe3 = (ne3 - sh3) % ne3;

-    const int shNe0 = ne0 - sh0;
-    const int shNe1 = ne1 - sh1;
-    const int shNe2 = ne2 - sh2;
-    const int shNe3 = ne3 - sh3;

-    const float *src_d = (const float *) src->data;
-    float *dst_d = (float *) dst->data;
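+    // Fuse i1 and i0 into the innermost ND-range dimension: the old kernel
+    // looped over ne0 inside each work-item, serializing the row axis.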
+    const size_t g0 = (size_t) ne3;
+    const size_t g1 = (size_t) ne2;
+    const size_t g2 = (size_t) (ne1 * ne0);

-    if (!src_d || !dst_d) throw std::runtime_error("null data pointers");
+    const range<3> global{ g0, g1, g2 };

    q.submit([&](handler &h) {
-        range<3> r((size_t) ne3, (size_t) ne2, (size_t) ne1);
-        h.parallel_for(r, [=](id<3> idx) {
-            const int i3 = (int) idx[0];
-            const int i2 = (int) idx[1];
-            const int i1 = (int) idx[2];
+        h.parallel_for(global, [=](id<3> idx) {
+            const int i3 = (int) idx[0];
+            const int i2 = (int) idx[1];
+
+            const int fused = (int) idx[2];
+            const int i1 = fused / ne0;
+            const int i0 = fused - i1 * ne0; // fused % ne0
+

-            for (int i0 = 0; i0 < ne0; i0++) {
-                const int idx_dst = i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2;
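+            // Linear offsets assume contiguous f32 tensors with ne0
+            // innermost; strides were precomputed on the host.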
+            const int idx_dst = i0
+                              + i1 * stride1
+                              + i2 * stride2
+                              + i3 * stride3;


-                const int src_i0 = (i0 + shNe0) % ne0;
-                const int src_i1 = (i1 + shNe1) % ne1;
-                const int src_i2 = (i2 + shNe2) % ne2;
-                const int src_i3 = (i3 + shNe3) % ne3;
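+            // One compare-and-subtract per axis instead of four % ops.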
+            const int s0 = wrap_add(i0, shNe0, ne0);
+            const int s1 = wrap_add(i1, shNe1, ne1);
+            const int s2 = wrap_add(i2, shNe2, ne2);
+            const int s3 = wrap_add(i3, shNe3, ne3);

-                const int idx_src = src_i0 + src_i1 * ne0 +
-                                    src_i2 * ne0 * ne1 + src_i3 * ne0 * ne1 * ne2;
+            const int idx_src = s0
+                              + s1 * stride1
+                              + s2 * stride2
+                              + s3 * stride3;

-                dst_d[idx_dst] = src_d[idx_src];
-            }
+            dst_d[idx_dst] = src_d[idx_src];
        });
    });
}
@@ -63,24 +71,50 @@ void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const ggml_tensor *src = dst->src[0];
+    GGML_ASSERT(src && src->type == GGML_TYPE_F32);

-    const int32_t *params = (const int32_t *) dst->op_params;
-    const int shift0 = params[0];
-    const int shift1 = params[1];
-    const int shift2 = params[2];
-    const int shift3 = params[3];
+    const int ne0 = (int) dst->ne[0];
+    const int ne1 = (int) dst->ne[1];
+    const int ne2 = (int) dst->ne[2];
+    const int ne3 = (int) dst->ne[3];

+    const int32_t *params = (const int32_t *) dst->op_params;
+    int shift0 = params[0];
+    int shift1 = params[1];
+    int shift2 = params[2];
+    int shift3 = params[3];

-    if (shift0 == 0 && shift1 == 0 && shift2 == 0 && shift3 == 0) {
+
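+    // All shifts zero: roll is the identity, so a device memcpy suffices.
+    // Bitwise OR of the four ints is zero iff every shift is zero.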
+    if ((shift0 | shift1 | shift2 | shift3) == 0) {
        const size_t nb = ggml_nbytes(src);
        queue *q = ctx.stream();
        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
        return;
    }

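+    // Normalize each shift into [0, n): C++ % keeps the sign of the
+    // dividend, so negative shifts need the explicit fix-up below.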
+    auto norm = [](int sh, int n) -> int {
+        if (n <= 0) return 0;
+        sh %= n;
+        if (sh < 0) sh += n;
+        return sh;
+    };
+    shift0 = norm(shift0, ne0);
+    shift1 = norm(shift1, ne1);
+    shift2 = norm(shift2, ne2);
+    shift3 = norm(shift3, ne3);
+
    try {
        queue *q = ctx.stream();
-        kernel_roll_multi_axis(*q, src, dst, shift0, shift1, shift2, shift3);
+
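+        // Raw f32 pointers; assumed device-visible USM allocations, as
+        // elsewhere in this backend.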
+        const float *src_d = (const float *) src->data;
+        float *dst_d = (float *) dst->data;
+        GGML_ASSERT(src_d && dst_d);
+
+        kernel_roll_fused_i0_i1(
+            *q, src_d, dst_d,
+            ne0, ne1, ne2, ne3,
+            shift0, shift1, shift2, shift3
+        );
    } catch (const std::exception &e) {
        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
        throw;