@@ -49,10 +49,7 @@ static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const
4949
5050 if (i0 >= n_dims) {
5151 const int i = row * ne0 + i0;
52-
53- dst[i + 0 ] = x[i + 0 ];
54- dst[i + 1 ] = x[i + 1 ];
55-
52+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
5653 return ;
5754 }
5855
@@ -93,10 +90,7 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
9390
9491 if (i0 >= n_dims) {
9592 const int i = row * ne0 + i0;
96-
97- dst[i + 0 ] = x[i + 0 ];
98- dst[i + 1 ] = x[i + 1 ];
99-
93+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
10094 return ;
10195 }
10296
@@ -122,6 +116,63 @@ static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const
122116 dst[i + n_dims / 2 ] = x0 * sin_theta + x1 * cos_theta;
123117}
124118
119+ template <typename T, bool has_ff>
120+ static void rope_multi (const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
121+ const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
122+ const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
123+ const float theta_scale, const float * freq_factors, const mrope_sections sections,
124+ const sycl::nd_item<3 > & item_ct1) {
125+ // get index pos
126+ const int i0 = 2 * (item_ct1.get_group (1 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 ));
127+ if (i0 >= ne0) {
128+ return ;
129+ }
130+ const int row_dst = (item_ct1.get_group (2 ) * item_ct1.get_local_range (2 )) + item_ct1.get_local_id (2 );
131+
132+ if (i0 >= n_dims) {
133+ const int i = row_dst*ne0 + i0;
134+ *reinterpret_cast <sycl::vec<T, 2 > *>(dst + i) = *reinterpret_cast <const sycl::vec<T, 2 > *>(x + i);
135+ return ;
136+ }
137+
138+ const int row_x = row_dst % ne1;
139+ const int channel_x = row_dst / ne1;
140+ const int idst = (row_dst * ne0) + (i0 / 2 );
141+ const size_t ix = ((size_t ) channel_x * s2) + ((size_t ) row_x * s1) + (i0 / 2 );
142+
143+ const int sect_dims = sections.v [0 ] + sections.v [1 ] + sections.v [2 ] + sections.v [3 ];
144+ const int sec_w = sections.v [1 ] + sections.v [0 ];
145+ const int sector = (i0 / 2 ) % sect_dims;
146+
147+
148+ float theta_base = 0.0 ;
149+ if (sector < sections.v [0 ]) {
150+ theta_base = pos[channel_x]*sycl::pow (theta_scale, i0/2 .0f );
151+ }
152+ else if (sector >= sections.v [0 ] && sector < sec_w) {
153+ theta_base = pos[channel_x + ne2 * 1 ]*sycl::pow (theta_scale, i0/2 .0f );
154+ }
155+ else if (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
156+ theta_base = pos[channel_x + ne2 * 2 ]*sycl::pow (theta_scale, i0/2 .0f );
157+ }
158+ else if (sector >= sec_w + sections.v [2 ]) {
159+ theta_base = pos[channel_x + ne2 * 3 ]*sycl::pow (theta_scale, i0/2 .0f );
160+ }
161+
162+ const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
163+ float cos_theta;
164+ float sin_theta;
165+ rope_yarn (theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
166+ const float x0 = x[ix + 0 ];
167+ const float x1 = x[ix + n_dims/2 ];
168+
169+ // store results in dst
170+ dst[idst + 0 ] = x0 * cos_theta - x1 * sin_theta;
171+ dst[idst + n_dims/2 ] = x0 * sin_theta + x1 * cos_theta;
172+ }
173+
174+
175+
125176template <typename T, bool has_ff>
126177static void rope_vision (const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
127178 const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
@@ -171,7 +222,7 @@ static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, c
171222 const float * freq_factors, queue_ptr stream) {
172223 GGML_ASSERT (ne0 % 2 == 0 );
173224 const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
174- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
225+ const int num_blocks_x = ceil_div (ne0, (2 * SYCL_ROPE_BLOCK_SIZE) );
175226 const sycl::range<3 > block_nums (1 , num_blocks_x, nr);
176227
177228 const float theta_scale = powf (freq_base, -2 .0f / n_dims);
@@ -208,7 +259,7 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
208259 const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
209260 GGML_ASSERT (ne0 % 2 == 0 );
210261 const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
211- const int num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
262+ const int num_blocks_x = ceil_div (ne0, (2 * SYCL_ROPE_BLOCK_SIZE) );
212263 const sycl::range<3 > block_nums (1 , num_blocks_x, nr);
213264
214265 const float theta_scale = powf (freq_base, -2 .0f / n_dims);
@@ -228,6 +279,40 @@ static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, c
228279 }
229280}
230281
282+ template <typename T>
283+ static void rope_multi_sycl (const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
284+ const size_t s2, const int n_dims, const int nr, const int32_t * pos,
285+ const float freq_scale, const float freq_base, const float ext_factor,
286+ const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
287+ const mrope_sections sections, queue_ptr stream) {
288+ GGML_ASSERT (ne0 % 2 == 0 );
289+ const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
290+ const int n_blocks_y = ceil_div (ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
291+ const sycl::range<3 > grid_dims (1 , n_blocks_y, nr);
292+ const sycl::nd_range<3 > nd_range (grid_dims * block_dims, block_dims);
293+
294+ const float theta_scale = std::pow (freq_base, -2 .0f / n_dims);
295+ // Add FP16 capability check if T could be sycl::half
296+ if constexpr (std::is_same_v<T, sycl::half>) {
297+ dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
298+ }
299+ // launch kernel
300+ if (freq_factors == nullptr ) {
301+ stream->parallel_for (nd_range, [=](sycl::nd_item<3 > item_ct1) {
302+ rope_multi<T, false >(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
303+ corr_dims, theta_scale, freq_factors, sections, item_ct1);
304+ });
305+ } else {
306+ stream->parallel_for (nd_range, [=](sycl::nd_item<3 > item_ct1) {
307+ rope_multi<T, true >(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
308+ corr_dims, theta_scale, freq_factors, sections, item_ct1);
309+ });
310+ }
311+ }
312+
313+
314+
315+
231316// rope vision
232317template <typename T>
233318static void rope_vision_sycl (const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
@@ -237,7 +322,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
237322 const mrope_sections sections, queue_ptr stream) {
238323 GGML_ASSERT (ne0 % 2 == 0 );
239324 const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
240- const int n_blocks_y = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1 ) / (2 * SYCL_ROPE_BLOCK_SIZE);
325+ const int n_blocks_y = ceil_div (ne0, (2 * SYCL_ROPE_BLOCK_SIZE) );
241326 const sycl::range<3 > grid_dims (1 , n_blocks_y, nr);
242327 const sycl::nd_range<3 > nd_range (grid_dims * block_dims, block_dims);
243328
@@ -298,8 +383,17 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
298383 memcpy (§ions.v , (int32_t *) dst->op_params + 11 , sizeof (int )*4 );
299384
300385 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
386+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
301387 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
302388
389+ if (is_mrope) {
390+ GGML_ASSERT (sections.v [0 ] > 0 || sections.v [1 ] > 0 || sections.v [2 ] > 0 );
391+ }
392+
393+ if (is_vision) {
394+ GGML_ASSERT (n_dims == ne00/2 );
395+ }
396+
303397 const int32_t * pos = (const int32_t *) dst->src [1 ]->data ;
304398
305399 const float * freq_factors = nullptr ;
@@ -326,6 +420,19 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
326420 } else {
327421 GGML_ABORT (" fatal error" );
328422 }
423+ } else if (is_mrope && !is_vision) {
424+ GGML_SYCL_DEBUG (" %s: mrope path\n " , __func__);
425+ if (dst->src [0 ]->type == GGML_TYPE_F16) {
426+ rope_multi_sycl ((const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, ne01, ne02, s01,
427+ s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
428+ freq_factors, sections, main_stream);
429+ } else if (dst->src [0 ]->type == GGML_TYPE_F32) {
430+ rope_multi_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, ne02, s01, s02, n_dims,
431+ nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
432+ main_stream);
433+ } else {
434+ GGML_ABORT (" Fatal error: Tensor type unsupported!" );
435+ }
329436 } else if (is_vision) {
330437 GGML_SYCL_DEBUG (" %s: vision path\n " , __func__);
331438 if (dst->src [0 ]->type == GGML_TYPE_F16) {
0 commit comments