@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
119119 const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120120 const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121121 const float theta_scale, const float * freq_factors, const mrope_sections sections,
122- const sycl::nd_item<3 > & item_ct1) {
122+ const bool is_imrope, const sycl::nd_item<3 > & item_ct1) {
123123 // get index pos
124124 const int i0 = 2 * (item_ct1.get_group (1 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 ));
125125 if (i0 >= ne0) {
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
143143
144144
145145 float theta_base = 0.0 ;
146- if (sector < sections.v [0 ]) {
147- theta_base = pos[channel_x]*sycl::pow (theta_scale, i0/2 .0f );
148- }
149- else if (sector >= sections.v [0 ] && sector < sec_w) {
150- theta_base = pos[channel_x + ne2 * 1 ]*sycl::pow (theta_scale, i0/2 .0f );
151- }
152- else if (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
153- theta_base = pos[channel_x + ne2 * 2 ]*sycl::pow (theta_scale, i0/2 .0f );
154- }
155- else if (sector >= sec_w + sections.v [2 ]) {
156- theta_base = pos[channel_x + ne2 * 3 ]*sycl::pow (theta_scale, i0/2 .0f );
146+ if (is_imrope) {
147+ if (sector % 3 == 1 && sector < 3 * sections.v [1 ]) {
148+ theta_base = pos[channel_x + ne2 * 1 ]*sycl::pow (theta_scale, i0/2 .0f );
149+ } else if (sector % 3 == 2 && sector < 3 * sections.v [2 ]) {
150+ theta_base = pos[channel_x + ne2 * 2 ]*sycl::pow (theta_scale, i0/2 .0f );
151+ } else if (sector % 3 == 0 && sector < 3 * sections.v [0 ]) {
152+ theta_base = pos[channel_x]*sycl::pow (theta_scale, i0/2 .0f );
153+ } else {
154+ theta_base = pos[channel_x + ne2 * 3 ]*sycl::pow (theta_scale, i0/2 .0f );
155+ }
156+ } else {
157+ if (sector < sections.v [0 ]) {
158+ theta_base = pos[channel_x]*sycl::pow (theta_scale, i0/2 .0f );
159+ }
160+ else if (sector >= sections.v [0 ] && sector < sec_w) {
161+ theta_base = pos[channel_x + ne2 * 1 ]*sycl::pow (theta_scale, i0/2 .0f );
162+ }
163+ else if (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
164+ theta_base = pos[channel_x + ne2 * 2 ]*sycl::pow (theta_scale, i0/2 .0f );
165+ }
166+ else if (sector >= sec_w + sections.v [2 ]) {
167+ theta_base = pos[channel_x + ne2 * 3 ]*sycl::pow (theta_scale, i0/2 .0f );
168+ }
157169 }
158170
159171 const float freq_factor = has_ff ? freq_factors[i0 / 2 ] : 1 .0f ;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
281293 const size_t s2, const int n_dims, const int nr, const int32_t * pos,
282294 const float freq_scale, const float freq_base, const float ext_factor,
283295 const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
284- const mrope_sections sections, queue_ptr stream) {
296+ const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
285297 GGML_ASSERT (ne0 % 2 == 0 );
286298 const sycl::range<3 > block_dims (1 , SYCL_ROPE_BLOCK_SIZE, 1 );
287299 const int n_blocks_y = ceil_div (ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
297309 if (freq_factors == nullptr ) {
298310 stream->parallel_for (nd_range, [=](sycl::nd_item<3 > item_ct1) {
299311 rope_multi<T, false >(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
300- corr_dims, theta_scale, freq_factors, sections, item_ct1);
312+ corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
301313 });
302314 } else {
303315 stream->parallel_for (nd_range, [=](sycl::nd_item<3 > item_ct1) {
304316 rope_multi<T, true >(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
305- corr_dims, theta_scale, freq_factors, sections, item_ct1);
317+ corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
306318 });
307319 }
308320}
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
381393
382394 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
383395 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
396+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
384397 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
385398
386399 if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
422435 if (dst->src [0 ]->type == GGML_TYPE_F16) {
423436 rope_multi_sycl ((const sycl::half *)dst->src [0 ]->data , (sycl::half *)dst->data , ne00, ne01, ne02, s01,
424437 s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
425- freq_factors, sections, main_stream);
438+ freq_factors, sections, is_imrope, main_stream);
426439 } else if (dst->src [0 ]->type == GGML_TYPE_F32) {
427440 rope_multi_sycl ((const float *) dst->src [0 ]->data , (float *) dst->data , ne00, ne01, ne02, s01, s02, n_dims,
428441 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
429- main_stream);
442+ is_imrope, main_stream);
430443 } else {
431444 GGML_ABORT (" Fatal error: Tensor type unsupported!" );
432445 }
0 commit comments