@@ -5531,6 +5531,24 @@ static void ggml_mrope_cache_init(
55315531 }
55325532}
55335533
// Rotate pairs of values taken from src_data and store the result in dst_data.
//
// The cache is consumed two floats at a time: cache[k] is used as the cosine and
// cache[k + 1] as the sine, for k = 0, 2, 4, ... while k < n.
//
//   n        - upper bound (exclusive) on the cache index k
//   n_offset - distance, in elements, between the two members of a rotated pair
//   cache    - interleaved cos/sin values
//   src_data - input values
//   dst_data - output values, written at the same indices that are read from src_data
//   scale    - divisor mapping the cache index k to the index of the first pair member
//              (k/scale, integer division); GGML_ROPE_TYPE_NORMAL needs scale == 1 so
//              that the index equals k, all other cases use the default of 2 (k/2)
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2) {
    for (int64_t k = 0; k < n; k += 2) {
        const int64_t lo = k/scale;      // first member of the pair
        const int64_t hi = lo + n_offset; // second member of the pair

        const float c = cache[k];
        const float s = cache[k + 1];

        // read both inputs before writing either output
        const float a = src_data[lo];
        const float b = src_data[hi];

        dst_data[lo] = a*c - b*s;
        dst_data[hi] = a*s + b*c;
    }
}
5551+
55345552static void ggml_compute_forward_rope_f32 (
55355553 const ggml_compute_params * params,
55365554 ggml_tensor * dst,
@@ -5587,11 +5605,10 @@ static void ggml_compute_forward_rope_f32(
55875605 float corr_dims[2 ];
55885606 ggml_rope_yarn_corr_dims (n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
55895607
5590- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5591- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
5608+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // note: also true for vision (24 & 8 == true)
55925609 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
55935610
5594- if (is_mrope ) {
5611+ if (mrope_used ) {
55955612 GGML_ASSERT (sections[0 ] > 0 || sections[1 ] > 0 || sections[2 ] > 0 );
55965613 }
55975614
@@ -5617,7 +5634,7 @@ static void ggml_compute_forward_rope_f32(
56175634 for (int64_t i2 = 0 ; i2 < ne2; i2++) { // seq-len
56185635
56195636 float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5620- if (!is_mrope ) {
5637+ if (!mrope_used ) {
56215638 const int64_t p = pos[i2];
56225639 ggml_rope_cache_init (p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
56235640 }
@@ -5635,73 +5652,26 @@ static void ggml_compute_forward_rope_f32(
56355652 if (ir++ < ir0) continue ;
56365653 if (ir > ir1) break ;
56375654
5638- if (is_neox || is_mrope) {
5639- if (is_vision){
5640- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5641- const int64_t ic = i0/2 ;
5642-
5643- const float cos_theta = cache[i0 + 0 ];
5644- const float sin_theta = cache[i0 + 1 ];
5645-
5646- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5647- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5648-
5649- const float x0 = src[0 ];
5650- const float x1 = src[n_dims];
5651-
5652- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5653- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5654- }
5655- } else {
5656- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5657- const int64_t ic = i0/2 ;
5658-
5659- const float cos_theta = cache[i0 + 0 ];
5660- const float sin_theta = cache[i0 + 1 ];
5661-
5662- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5663- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5664-
5665- const float x0 = src[0 ];
5666- const float x1 = src[n_dims/2 ];
5655+ float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5656+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
56675657
5668- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5669- dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
5670- }
5671- }
5672- } else {
5673- for (int64_t i0 = 0 ; i0 < n_dims; i0 += 2 ) {
5674- const float cos_theta = cache[i0 + 0 ];
5675- const float sin_theta = cache[i0 + 1 ];
5676-
5677- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5678- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5679-
5680- const float x0 = src[0 ];
5681- const float x1 = src[1 ];
5682-
5683- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5684- dst_data[1 ] = x0*sin_theta + x1*cos_theta;
5685- }
5658+ switch (mode) {
5659+ case GGML_ROPE_TYPE_NORMAL:
5660+ rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
5661+ break ;
5662+ case GGML_ROPE_TYPE_NEOX:
5663+ case GGML_ROPE_TYPE_MROPE: // pure, not vision
5664+ rotate_pairs (n_dims, n_dims/2 , cache, src, dst_data);
5665+ break ;
5666+ case GGML_ROPE_TYPE_VISION:
5667+ rotate_pairs (ne0, n_dims, cache, src, dst_data);
5668+ break ;
5669+ default :
5670+ // rope type not supported, silently default to NORMAL
5671+ rotate_pairs (n_dims, 1 , cache, src, dst_data, 1 );
56865672 }
56875673
5688- if (is_vision) {
5689- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
5690- const int64_t ic = i0/2 ;
5691-
5692- const float cos_theta = cache[i0 + 0 ];
5693- const float sin_theta = cache[i0 + 1 ];
5694-
5695- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5696- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5697-
5698- const float x0 = src[0 ];
5699- const float x1 = src[n_dims];
5700-
5701- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
5702- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5703- }
5704- } else {
5674+ if (!is_vision) {
57055675 // fill the remain channels with data from src tensor
57065676 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2 ) {
57075677 const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -5711,7 +5681,7 @@ static void ggml_compute_forward_rope_f32(
57115681 dst_data[1 ] = src[1 ];
57125682 }
57135683 }
5714- }
5684+ } // attn-heads
57155685 }
57165686 }
57175687}
0 commit comments