@@ -5543,6 +5543,24 @@ static void ggml_mrope_cache_init(
55435543 }
55445544}
55455545
// Apply a 2-D rotation to pairs of floats, reading from src_data and writing to dst_data.
//
// For every even i0 in [0, n):
//   - cache[i0] is used as the cosine and cache[i0 + 1] as the sine of the rotation,
//   - the rotated pair is (src_data[ic], src_data[ic + n_offset]) with ic = i0/scale
//     (integer division), and the result is stored at the same two indices of dst_data.
//
// n        : upper bound (exclusive) for the even cache index i0
// n_offset : distance, in elements, between the two members of a pair
// cache    : interleaved cos/sin values, indexed by i0
// scale    : divisor mapping the cache index to the data index; the default of 2 advances
//            the data index by one element per pair, while scale = 1 advances it by two
//            (used together with n_offset = 1 for adjacent pairs, e.g. GGML_ROPE_TYPE_NORMAL)
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2) {
    for (int64_t pair = 0; 2*pair < n; ++pair) {
        const int64_t i_cache = 2*pair;          // index of this pair's cos/sin entry
        const int64_t i_lo    = i_cache/scale;   // index of the first element of the pair
        const int64_t i_hi    = i_lo + n_offset; // index of the second element of the pair

        const float c = cache[i_cache + 0];
        const float s = cache[i_cache + 1];

        // read both inputs before storing anything, so the outcome is unchanged
        // when src_data and dst_data refer to the same buffer
        const float a = src_data[i_lo];
        const float b = src_data[i_hi];

        dst_data[i_lo] = a*c - b*s;
        dst_data[i_hi] = a*s + b*c;
    }
}
5563+
55465564static void ggml_compute_forward_rope_f32 (
55475565 const ggml_compute_params * params,
55485566 ggml_tensor * dst,
@@ -5599,12 +5617,11 @@ static void ggml_compute_forward_rope_f32(
55995617 float corr_dims[2 ];
56005618 ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
56015619
5602- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5603- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
56045620 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5621+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true)
56055622 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
56065623
5607- if (is_mrope ) {
5624+ if (mrope_used ) {
56085625 GGML_ASSERT (sections[0 ] > 0 || sections[1 ] > 0 || sections[2 ] > 0 );
56095626 }
56105627
@@ -5630,7 +5647,7 @@ static void ggml_compute_forward_rope_f32(
56305647 for (int64_t i2 = 0 ; i2 < ne2; i2++) { // seq-len
56315648
56325649 float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5633- if (!is_mrope ) {
5650+ if (!mrope_used ) {
56345651 const int64_t p = pos[i2];
56355652 ggml_rope_cache_init (p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
56365653 }
@@ -5648,73 +5665,26 @@ static void ggml_compute_forward_rope_f32(
56485665 if (ir++ < ir0) continue ;
56495666 if (ir > ir1) break ;
56505667
5651- if (is_neox || is_mrope) {
5652- if (is_vision){
5653- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5654- const int64_t ic = i0/2 ;
5655-
5656- const float cos_theta = cache[i0 + 0 ];
5657- const float sin_theta = cache[i0 + 1 ];
5658-
5659- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5660- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5661-
5662- const float x0 = src[0 ];
5663- const float x1 = src[n_dims];
5664-
5665- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5666- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5667- }
5668- } else {
5669- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5670- const int64_t ic = i0/2 ;
5671-
5672- const float cos_theta = cache[i0 + 0 ];
5673- const float sin_theta = cache[i0 + 1 ];
5674-
5675- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5676- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5677-
5678- const float x0 = src[0 ];
5679- const float x1 = src[n_dims/2 ];
5668+ float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5669+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
56805670
5681- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5682- dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
5683- }
5684- }
5685- } else {
5686- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5687- const float cos_theta = cache[i0 + 0 ];
5688- const float sin_theta = cache[i0 + 1 ];
5689-
5690- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5691- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5692-
5693- const float x0 = src[0 ];
5694- const float x1 = src[1 ];
5695-
5696- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5697- dst_data[1 ] = x0*sin_theta + x1*cos_theta;
5698- }
5671+ switch (mode) {
5672+ case GGML_ROPE_TYPE_NORMAL:
5673+ rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
5674+ break ;
5675+ case GGML_ROPE_TYPE_NEOX:
5676+ case GGML_ROPE_TYPE_MROPE: // pure, not vision
5677+ rotate_pairs (n_dims, n_dims/2 , cache, src, dst_data);
5678+ break ;
5679+ case GGML_ROPE_TYPE_VISION:
5680+ rotate_pairs (ne0, n_dims, cache, src, dst_data);
5681+ break ;
5682+ default :
5683+ // rope type not supported, silently default to NORMAL
5684+ rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
56995685 }
57005686
5701- if (is_vision) {
5702- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
5703- const int64_t ic = i0/2 ;
5704-
5705- const float cos_theta = cache[i0 + 0 ];
5706- const float sin_theta = cache[i0 + 1 ];
5707-
5708- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5709- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5710-
5711- const float x0 = src[0 ];
5712- const float x1 = src[n_dims];
5713-
5714- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5715- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5716- }
5717- } else {
5687+ if (!is_vision) {
57185688 // fill the remain channels with data from src tensor
57195689 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
57205690 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -5724,7 +5694,7 @@ static void ggml_compute_forward_rope_f32(
57245694 dst_data[1 ] = src[1 ];
57255695 }
57265696 }
5727- }
5697+ } // attn-heads
57285698 }
57295699 }
57305700}
0 commit comments