@@ -47,18 +47,18 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
4747
4848 const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
4949
50- if (i0 >= n_dims) {
51- const int i = row * ne0 + i0;
52- *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
53- return ;
54- }
55-
5650 const int row0 = row % ne1;
5751 const int channel0 = row / ne1;
5852
5953 const int i = row * ne0 + i0;
6054 const int i2 = channel0 * s2 + row0 * s1 + i0;
6155
56+ if (i0 >= n_dims) {
57+ const int i = row * ne0 + i0;
58+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i2);
59+ return ;
60+ }
61+
6262 const float theta_base = pos[channel0] * sycl::pow (theta_scale, i0 / 2 .0f );
6363
6464 const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
@@ -88,18 +88,18 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
8888
8989 const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
9090
91- if (i0 >= n_dims) {
92- const int i = row * ne0 + i0;
93- *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
94- return ;
95- }
96-
9791 const int row0 = row % ne1;
9892 const int channel0 = row / ne1;
9993
10094 const int i = row * ne0 + i0 / 2 ;
10195 const int i2 = channel0 * s2 + row0 * s1 + i0 / 2 ;
10296
97+ if (i0 >= n_dims) {
98+ const int i = row * ne0 + i0;
99+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i0 / 2 + i);
100+ return ;
101+ }
102+
103103 const float theta_base = pos[channel0] * sycl::pow (theta_scale, i0 / 2 .0f );
104104
105105 const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
@@ -129,17 +129,17 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
129129 }
130130 const int row_dst = (item_ct1.get_group (2 ) * item_ct1.get_local_range (2 )) + item_ct1.get_local_id (2 );
131131
132- if (i0 >= n_dims) {
133- const int i = row_dst*ne0 + i0;
134- *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
135- return ;
136- }
137-
138132 const int row_x = row_dst % ne1;
139133 const int channel_x = row_dst / ne1;
140134 const int idst = (row_dst * ne0) + (i0 / 2 );
141135 const size_t ix = ((size_t ) channel_x * s2) + ((size_t ) row_x * s1) + (i0 / 2 );
142136
137+ if (i0 >= n_dims) {
138+ const int i = row_dst*ne0 + i0;
139+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i0 / 2 + i);
140+ return ;
141+ }
142+
143143 const int sect_dims = sections.v [0 ] + sections.v [1 ] + sections.v [2 ] + sections.v [3 ];
144144 const int sec_w = sections.v [1 ] + sections.v [0 ];
145145 const int sector = (i0 / 2 ) % sect_dims;
0 commit comments