@@ -125,7 +125,7 @@ template<bool forward, bool has_ff, typename T>
125125static  __global__  void  rope_multi (
126126        const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  ne2, const  int  s1, const  int  s2,
127127        const  int  n_dims, const  int32_t  * pos, const  float  freq_scale, const  float  ext_factor, const  float  attn_factor,
128-         const  rope_corr_dims corr_dims, const  float  theta_scale, const  float  * freq_factors, const  mrope_sections sections) {
128+         const  rope_corr_dims corr_dims, const  float  theta_scale, const  float  * freq_factors, const  mrope_sections sections,  const   bool  is_imrope ) {
129129    const  int  i0 = 2 *(blockDim .y *blockIdx .y  + threadIdx .y );
130130
131131    if  (i0 >= ne0) {
@@ -152,17 +152,27 @@ static __global__ void rope_multi(
152152    const  int  sector = (i0 / 2 ) % sect_dims;
153153
154154    float  theta_base = 0.0 ;
155-     if  (sector < sections.v [0 ]) {
156-         theta_base = pos[channel_x]*powf (theta_scale, i0/2 .0f );
157-     }
158-     else  if  (sector >= sections.v [0 ] && sector < sec_w) {
159-         theta_base = pos[channel_x + ne2 * 1 ]*powf (theta_scale, i0/2 .0f );
160-     }
161-     else  if  (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
162-         theta_base = pos[channel_x + ne2 * 2 ]*powf (theta_scale, i0/2 .0f );
163-     }
164-     else  if  (sector >= sec_w + sections.v [2 ]) {
165-         theta_base = pos[channel_x + ne2 * 3 ]*powf (theta_scale, i0/2 .0f );
155+     if  (is_imrope) {
156+         if  (sector % 3  == 1  && sector < 3  * sections.v [1 ]) { //  h
157+             theta_base = pos[channel_x + ne2 * 1 ]*powf (theta_scale, i0/2 .0f );
158+         } else  if  (sector % 3  == 2  && sector < 3  * sections.v [2 ]) { //  w
159+             theta_base = pos[channel_x + ne2 * 2 ]*powf (theta_scale, i0/2 .0f );
160+         } else  if  (sector % 3  == 0  && sector < 3  * sections.v [0 ]) { //  t
161+             theta_base = pos[channel_x]*powf (theta_scale, i0/2 .0f );
162+         }
163+     } else  {
164+         if  (sector < sections.v [0 ]) {
165+             theta_base = pos[channel_x]*powf (theta_scale, i0/2 .0f );
166+         }
167+         else  if  (sector >= sections.v [0 ] && sector < sec_w) {
168+             theta_base = pos[channel_x + ne2 * 1 ]*powf (theta_scale, i0/2 .0f );
169+         }
170+         else  if  (sector >= sec_w && sector < sec_w + sections.v [2 ]) {
171+             theta_base = pos[channel_x + ne2 * 2 ]*powf (theta_scale, i0/2 .0f );
172+         }
173+         else  if  (sector >= sec_w + sections.v [2 ]) {
174+             theta_base = pos[channel_x + ne2 * 3 ]*powf (theta_scale, i0/2 .0f );
175+         }
166176    }
167177
168178    const  float  freq_factor = has_ff ? freq_factors[i0/2 ] : 1 .0f ;
@@ -276,7 +286,7 @@ template<bool forward, typename T>
276286static  void  rope_multi_cuda (
277287        const  T * x, T * dst, const  int  ne0, const  int  ne1, const  int  ne2, const  int  s1, const  int  s2, const  int  n_dims, const  int  nr,
278288        const  int32_t  * pos, const  float  freq_scale, const  float  freq_base, const  float  ext_factor, const  float  attn_factor,
279-         const  rope_corr_dims corr_dims, const  float  * freq_factors, const  mrope_sections sections, cudaStream_t stream) {
289+         const  rope_corr_dims corr_dims, const  float  * freq_factors, const  mrope_sections sections, const   bool  is_imrope,  cudaStream_t stream) {
280290    GGML_ASSERT (ne0 % 2  == 0 );
281291    const  dim3  block_dims (1 , CUDA_ROPE_BLOCK_SIZE, 1 );
282292    const  int  n_blocks_x = (ne0 + 2 *CUDA_ROPE_BLOCK_SIZE - 1 ) / (2 *CUDA_ROPE_BLOCK_SIZE);
@@ -287,11 +297,11 @@ static void rope_multi_cuda(
287297    if  (freq_factors == nullptr ) {
288298        rope_multi<forward, false , T><<<block_nums, block_dims, 0 , stream>>> (
289299            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
290-             attn_factor, corr_dims, theta_scale, freq_factors, sections);
300+             attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope );
291301    } else  {
292302        rope_multi<forward, true , T><<<block_nums, block_dims, 0 , stream>>> (
293303            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
294-             attn_factor, corr_dims, theta_scale, freq_factors, sections);
304+             attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope );
295305    }
296306}
297307
@@ -369,6 +379,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
369379
370380    const  bool  is_neox = mode & GGML_ROPE_TYPE_NEOX;
371381    const  bool  is_mrope = mode & GGML_ROPE_TYPE_MROPE;
382+     const  bool  is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
372383    const  bool  is_vision = mode == GGML_ROPE_TYPE_VISION;
373384
374385    if  (is_mrope) {
@@ -406,11 +417,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
406417        if  (src0->type  == GGML_TYPE_F32) {
407418            rope_multi_cuda<forward>(
408419                (const  float  *) src0_d, (float  *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
409-                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
420+                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope,  stream);
410421        } else  if  (src0->type  == GGML_TYPE_F16) {
411422            rope_multi_cuda<forward>(
412423                (const  half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
413-                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
424+                 freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope,  stream);
414425        } else  {
415426            GGML_ABORT (" fatal error"  );
416427        }
0 commit comments