Skip to content

Commit 97b0baf

Browse files
committed
extract rotate_pairs logic from ggml_compute_forward_rope_f32
1 parent 31c511a commit 97b0baf

File tree

2 files changed

+40
-69
lines changed

2 files changed

+40
-69
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 39 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5543,6 +5543,24 @@ static void ggml_mrope_cache_init(
55435543
}
55445544
}
55455545

5546+
// Rotate pairs of elements of one row by the angles stored in `cache`.
//
// The cache index i0 walks over the even values in [0, n). For each i0:
//   - cache[i0] is used as cos(theta) and cache[i0 + 1] as sin(theta)
//   - the pair (src_data[ic], src_data[ic + n_offset]), with ic = i0/scale, is rotated
//     and stored to (dst_data[ic], dst_data[ic + n_offset])
//
// Parameters:
//   n        - exclusive upper bound of the cache index i0 (advances in steps of 2)
//   n_offset - distance, in elements, between the two members of a rotated pair
//   cache    - interleaved cos/sin values
//   src_data - input row
//   dst_data - output row; both members of a pair are loaded before either result is stored
//   scale    - divisor that maps i0 to the element index of the first pair member:
//              1 is the special case for GGML_ROPE_TYPE_NORMAL (ic = i0),
//              every other rope type uses the default 2 (ic = i0/2)
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2) {
    for (int64_t pair = 0; 2*pair < n; ++pair) {
        const int64_t i0    = 2*pair;   // position of this pair's cos/sin entry in the cache
        const int64_t first = i0/scale; // element index of the first member of the pair
        const int64_t second = first + n_offset;

        const float cos_theta = cache[i0 + 0];
        const float sin_theta = cache[i0 + 1];

        // read both inputs up front so the result is correct even when the rows alias
        const float x0 = src_data[first];
        const float x1 = src_data[second];

        dst_data[first]  = x0*cos_theta - x1*sin_theta;
        dst_data[second] = x0*sin_theta + x1*cos_theta;
    }
}
5563+
55465564
static void ggml_compute_forward_rope_f32(
55475565
const ggml_compute_params * params,
55485566
ggml_tensor * dst,
@@ -5599,12 +5617,11 @@ static void ggml_compute_forward_rope_f32(
55995617
float corr_dims[2];
56005618
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
56015619

5602-
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5603-
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
56045620
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
5621+
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true)
56055622
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
56065623

5607-
if (is_mrope) {
5624+
if (mrope_used) {
56085625
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
56095626
}
56105627

@@ -5630,7 +5647,7 @@ static void ggml_compute_forward_rope_f32(
56305647
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
56315648

56325649
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5633-
if (!is_mrope) {
5650+
if (!mrope_used) {
56345651
const int64_t p = pos[i2];
56355652
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
56365653
}
@@ -5648,73 +5665,26 @@ static void ggml_compute_forward_rope_f32(
56485665
if (ir++ < ir0) continue;
56495666
if (ir > ir1) break;
56505667

5651-
if (is_neox || is_mrope) {
5652-
if (is_vision){
5653-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5654-
const int64_t ic = i0/2;
5655-
5656-
const float cos_theta = cache[i0 + 0];
5657-
const float sin_theta = cache[i0 + 1];
5658-
5659-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5660-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5661-
5662-
const float x0 = src[0];
5663-
const float x1 = src[n_dims];
5664-
5665-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5666-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5667-
}
5668-
} else {
5669-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5670-
const int64_t ic = i0/2;
5671-
5672-
const float cos_theta = cache[i0 + 0];
5673-
const float sin_theta = cache[i0 + 1];
5674-
5675-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5676-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5677-
5678-
const float x0 = src[0];
5679-
const float x1 = src[n_dims/2];
5668+
float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5669+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
56805670

5681-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5682-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
5683-
}
5684-
}
5685-
} else {
5686-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5687-
const float cos_theta = cache[i0 + 0];
5688-
const float sin_theta = cache[i0 + 1];
5689-
5690-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5691-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5692-
5693-
const float x0 = src[0];
5694-
const float x1 = src[1];
5695-
5696-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5697-
dst_data[1] = x0*sin_theta + x1*cos_theta;
5698-
}
5671+
switch (mode) {
5672+
case GGML_ROPE_TYPE_NORMAL:
5673+
rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
5674+
break;
5675+
case GGML_ROPE_TYPE_NEOX:
5676+
case GGML_ROPE_TYPE_MROPE: //pure, not vision
5677+
rotate_pairs(n_dims, n_dims/2, cache, src, dst_data);
5678+
break;
5679+
case GGML_ROPE_TYPE_VISION:
5680+
rotate_pairs(ne0, n_dims, cache, src, dst_data);
5681+
break;
5682+
default:
5683+
//rope type not supported, silently default to NORMAL
5684+
rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
56995685
}
57005686

5701-
if (is_vision) {
5702-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5703-
const int64_t ic = i0/2;
5704-
5705-
const float cos_theta = cache[i0 + 0];
5706-
const float sin_theta = cache[i0 + 1];
5707-
5708-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5709-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5710-
5711-
const float x0 = src[0];
5712-
const float x1 = src[n_dims];
5713-
5714-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5715-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5716-
}
5717-
} else {
5687+
if (!is_vision) {
57185688
// fill the remain channels with data from src tensor
57195689
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
57205690
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -5724,7 +5694,7 @@ static void ggml_compute_forward_rope_f32(
57245694
dst_data[1] = src[1];
57255695
}
57265696
}
5727-
}
5697+
} //attn-heads
57285698
}
57295699
}
57305700
}

tests/test-rope.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
164164
((int32_t *) p2->data)[i] = n_past_2 + i;
165165
}
166166
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
167+
// note: GLM is not implemented, it will default to standard
167168
mode = m == 0 ? 0 : m == 1 ? 2 : 4;
168169

169170
// 100, 101, 102, ..., 172

0 commit comments

Comments
 (0)