@@ -47,18 +47,17 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
47
47
48
48
const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
49
49
50
- if (i0 >= n_dims) {
51
- const int i = row * ne0 + i0;
52
- *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
53
- return ;
54
- }
55
-
56
50
const int row0 = row % ne1;
57
51
const int channel0 = row / ne1;
58
52
59
53
const int i = row * ne0 + i0;
60
54
const int i2 = channel0 * s2 + row0 * s1 + i0;
61
55
56
+ if (i0 >= n_dims) {
57
+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i2);
58
+ return ;
59
+ }
60
+
62
61
const float theta_base = pos[channel0] * sycl::pow (theta_scale, i0 / 2 .0f );
63
62
64
63
const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
@@ -88,18 +87,17 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
88
87
89
88
const int row = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) + item_ct1.get_local_id (2 );
90
89
91
- if (i0 >= n_dims) {
92
- const int i = row * ne0 + i0;
93
- *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
94
- return ;
95
- }
96
-
97
90
const int row0 = row % ne1;
98
91
const int channel0 = row / ne1;
99
92
100
93
const int i = row * ne0 + i0 / 2 ;
101
94
const int i2 = channel0 * s2 + row0 * s1 + i0 / 2 ;
102
95
96
+ if (i0 >= n_dims) {
97
+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i + i0 / 2 ) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i2 + i0 / 2 );
98
+ return ;
99
+ }
100
+
103
101
const float theta_base = pos[channel0] * sycl::pow (theta_scale, i0 / 2 .0f );
104
102
105
103
const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
@@ -129,17 +127,16 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
129
127
}
130
128
const int row_dst = (item_ct1.get_group (2 ) * item_ct1.get_local_range (2 )) + item_ct1.get_local_id (2 );
131
129
132
- if (i0 >= n_dims) {
133
- const int i = row_dst*ne0 + i0;
134
- *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
135
- return ;
136
- }
137
-
138
130
const int row_x = row_dst % ne1;
139
131
const int channel_x = row_dst / ne1;
140
132
const int idst = (row_dst * ne0) + (i0 / 2 );
141
133
const size_t ix = ((size_t ) channel_x * s2) + ((size_t ) row_x * s1) + (i0 / 2 );
142
134
135
+ if (i0 >= n_dims) {
136
+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + idst + i0 / 2 ) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i0 / 2 + ix);
137
+ return ;
138
+ }
139
+
143
140
const int sect_dims = sections.v [0 ] + sections.v [1 ] + sections.v [2 ] + sections.v [3 ];
144
141
const int sec_w = sections.v [1 ] + sections.v [0 ];
145
142
const int sector = (i0 / 2 ) % sect_dims;
0 commit comments