@@ -9,19 +9,25 @@ static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor
     if (src->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32)
         throw std::runtime_error("only F32 supported in SYCL roll");
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    const int64_t ne3 = dst->ne[3];
+    const int ne0 = dst->ne[0];
+    const int ne1 = dst->ne[1];
+    const int ne2 = dst->ne[2];
+    const int ne3 = dst->ne[3];
 
     if (ne0 != src->ne[0] || ne1 != src->ne[1] || ne2 != src->ne[2] || ne3 != src->ne[3])
         throw std::runtime_error("src/dst shape mismatch");
 
-    // Normalize shifts to be within bounds
-    const int64_t sh0 = ne0 > 0 ? ((int64_t) shift0 % ne0 + ne0) % ne0 : 0;
-    const int64_t sh1 = ne1 > 0 ? ((int64_t) shift1 % ne1 + ne1) % ne1 : 0;
-    const int64_t sh2 = ne2 > 0 ? ((int64_t) shift2 % ne2 + ne2) % ne2 : 0;
-    const int64_t sh3 = ne3 > 0 ? ((int64_t) shift3 % ne3 + ne3) % ne3 : 0;
+
+    const int sh0 = ne0 > 0 ? ((int) shift0 % ne0 + ne0) % ne0 : 0;
+    const int sh1 = ne1 > 0 ? ((int) shift1 % ne1 + ne1) % ne1 : 0;
+    const int sh2 = ne2 > 0 ? ((int) shift2 % ne2 + ne2) % ne2 : 0;
+    const int sh3 = ne3 > 0 ? ((int) shift3 % ne3 + ne3) % ne3 : 0;
+
+
+    const int shNe0 = ne0 - sh0;
+    const int shNe1 = ne1 - sh1;
+    const int shNe2 = ne2 - sh2;
+    const int shNe3 = ne3 - sh3;
 
     const float *src_d = (const float *) src->data;
     float *dst_d = (float *) dst->data;
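The rewritten indexing above relies on the identity (i - sh + ne) % ne == (i + (ne - sh)) % ne once sh has been normalized into [0, ne), which is exactly what the precomputed shNe* values exploit; the switch from int64_t to int presumably assumes each per-dimension extent fits in 32 bits. A minimal standalone check of that identity (a sketch for illustration, not part of the patch):

// sketch: verify (i - sh + ne) % ne == (i + shNe) % ne for a normalized shift
#include <cassert>
#include <cstdio>

int main() {
    const int ne = 5;                                // extent of one dimension
    for (int shift = -7; shift <= 7; ++shift) {
        const int sh   = ((shift % ne) + ne) % ne;   // normalized shift, as in the patch
        const int shNe = ne - sh;                    // precomputed complement
        for (int i = 0; i < ne; ++i) {
            assert((i - sh + ne) % ne == (i + shNe) % ne);
        }
    }
    std::printf("roll index identity holds for ne = %d\n", ne);
    return 0;
}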
@@ -31,26 +37,26 @@ static void kernel_roll_multi_axis(queue &q, const ggml_tensor *src, ggml_tensor
     q.submit([&](handler &h) {
         range<3> r((size_t) ne3, (size_t) ne2, (size_t) ne1);
         h.parallel_for(r, [=](id<3> idx) {
-            const int64_t i3 = idx[0];
-            const int64_t i2 = idx[1];
-            const int64_t i1 = idx[2];
+            const int i3 = (int) idx[0];
+            const int i2 = (int) idx[1];
+            const int i1 = (int) idx[2];
 
-            for (int64_t i0 = 0; i0 < ne0; i0++) {
-                const int64_t idx_dst = i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2;
+            for (int i0 = 0; i0 < ne0; i0++) {
+                const int idx_dst = i0 + i1 * ne0 + i2 * ne0 * ne1 + i3 * ne0 * ne1 * ne2;
 
-                // Apply shift to each dimension
-                const int64_t src_i0 = (i0 - sh0 + ne0) % ne0;
-                const int64_t src_i1 = (i1 - sh1 + ne1) % ne1;
-                const int64_t src_i2 = (i2 - sh2 + ne2) % ne2;
-                const int64_t src_i3 = (i3 - sh3 + ne3) % ne3;
 
-                const int64_t idx_src = src_i0 + src_i1 * ne0 +
+                const int src_i0 = (i0 + shNe0) % ne0;
+                const int src_i1 = (i1 + shNe1) % ne1;
+                const int src_i2 = (i2 + shNe2) % ne2;
+                const int src_i3 = (i3 + shNe3) % ne3;
+
+                const int idx_src = src_i0 + src_i1 * ne0 +
                                     src_i2 * ne0 * ne1 + src_i3 * ne0 * ne1 * ne2;
 
                 dst_d[idx_dst] = src_d[idx_src];
             }
         });
-    }).wait();
+    });
 }
 
 void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
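Dropping the trailing .wait() here (and on the memcpy fast path in the next hunk) turns the submission into an asynchronous enqueue; correctness then presumably rests on the backend's in-order stream and on whatever later synchronization the SYCL backend performs before results are read back. A standalone sketch of that pattern (a hypothetical example under those assumptions, not llama.cpp code):

// sketch: an enqueue without .wait() needs an explicit queue.wait() (or an
// event dependency) before the host may read the result
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
    sycl::queue q{sycl::property::queue::in_order()};     // in-order, stream-like queue
    float *buf = sycl::malloc_shared<float>(4, q);

    q.submit([&](sycl::handler &h) {
        h.parallel_for(sycl::range<1>(4), [=](sycl::id<1> i) {
            buf[i] = (float) i[0];                         // device-side write
        });
    });                                                    // returns immediately, no .wait()

    q.wait();                                              // synchronize before the host read
    std::printf("buf[3] = %f\n", buf[3]);
    sycl::free(buf, q);
    return 0;
}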
@@ -64,11 +70,11 @@ void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     const int shift2 = params[2];
     const int shift3 = params[3];
 
-    // Check if all shifts are zero
+
     if (shift0 == 0 && shift1 == 0 && shift2 == 0 && shift3 == 0) {
         const size_t nb = ggml_nbytes(src);
         queue *q = ctx.stream();
-        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb).wait()));
+        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
         return;
     }
 