Skip to content

Commit 4b540ab

Browse files
committed
extract rotate_pairs logic from ggml_compute_forward_rope_f32
1 parent 1c1409e commit 4b540ab

File tree

2 files changed

+40
-69
lines changed

2 files changed

+40
-69
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 39 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5531,6 +5531,24 @@ static void ggml_mrope_cache_init(
55315531
}
55325532
}
55335533

5534+
// Rotates pairs of floats by the angles stored in `cache`, reading the inputs
// from `src_data` and writing the results to `dst_data`.
//
// For every even i0 in [0, n):
//     ic                      = i0 / scale
//     cos_theta, sin_theta    = cache[i0], cache[i0 + 1]
//     x0, x1                  = src_data[ic], src_data[ic + n_offset]
//     dst_data[ic]            = x0*cos_theta - x1*sin_theta
//     dst_data[ic + n_offset] = x0*sin_theta + x1*cos_theta
//
// Parameters:
//     n        - exclusive upper bound of i0. cache[i0 + 1] is read for each
//                even i0 < n, so n is expected to be even.
//     n_offset - distance, in elements, between the two members of a pair.
//     cache    - interleaved {cos, sin} values, indexed by i0.
//     src_data - base pointer of the input row.
//     dst_data - base pointer of the output row.
//     scale    - divisor that maps i0 to the element index of the pair's first
//                member. With scale == 1, ic == i0 (together with n_offset == 1
//                this addresses adjacent elements). With the default of 2,
//                ic == i0/2, i.e. consecutive pairs start at consecutive
//                elements. Must be non-zero (it is used as a divisor).
//
// Both members of a pair are loaded before either result is stored, so a single
// pair can be rotated in place when src_data and dst_data address the same row.
static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const float * src_data, float * dst_data, const int scale = 2) {
    for (int64_t i0 = 0; i0 < n; i0 += 2) {
        // element index of the first member of this pair
        const int64_t ic = i0/scale;

        const float cos_theta = cache[i0 + 0];
        const float sin_theta = cache[i0 + 1];

        const float * const src = src_data + ic;
        float * dst = dst_data + ic;

        // load both inputs before writing anything
        const float x0 = src[0];
        const float x1 = src[n_offset];

        dst[0]        = x0*cos_theta - x1*sin_theta;
        dst[n_offset] = x0*sin_theta + x1*cos_theta;
    }
}
5551+
55345552
static void ggml_compute_forward_rope_f32(
55355553
const ggml_compute_params * params,
55365554
ggml_tensor * dst,
@@ -5587,11 +5605,10 @@ static void ggml_compute_forward_rope_f32(
55875605
float corr_dims[2];
55885606
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
55895607

5590-
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
5591-
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
5608+
const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // note: also true for vision (24 & 8 == true)
55925609
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
55935610

5594-
if (is_mrope) {
5611+
if (mrope_used) {
55955612
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
55965613
}
55975614

@@ -5617,7 +5634,7 @@ static void ggml_compute_forward_rope_f32(
56175634
for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
56185635

56195636
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5620-
if (!is_mrope) {
5637+
if (!mrope_used) {
56215638
const int64_t p = pos[i2];
56225639
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
56235640
}
@@ -5635,73 +5652,26 @@ static void ggml_compute_forward_rope_f32(
56355652
if (ir++ < ir0) continue;
56365653
if (ir > ir1) break;
56375654

5638-
if (is_neox || is_mrope) {
5639-
if (is_vision){
5640-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5641-
const int64_t ic = i0/2;
5642-
5643-
const float cos_theta = cache[i0 + 0];
5644-
const float sin_theta = cache[i0 + 1];
5645-
5646-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5647-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5648-
5649-
const float x0 = src[0];
5650-
const float x1 = src[n_dims];
5651-
5652-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5653-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5654-
}
5655-
} else {
5656-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5657-
const int64_t ic = i0/2;
5658-
5659-
const float cos_theta = cache[i0 + 0];
5660-
const float sin_theta = cache[i0 + 1];
5661-
5662-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5663-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5664-
5665-
const float x0 = src[0];
5666-
const float x1 = src[n_dims/2];
5655+
float * src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5656+
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
56675657

5668-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5669-
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
5670-
}
5671-
}
5672-
} else {
5673-
for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
5674-
const float cos_theta = cache[i0 + 0];
5675-
const float sin_theta = cache[i0 + 1];
5676-
5677-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5678-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5679-
5680-
const float x0 = src[0];
5681-
const float x1 = src[1];
5682-
5683-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5684-
dst_data[1] = x0*sin_theta + x1*cos_theta;
5685-
}
5658+
switch (mode) {
5659+
case GGML_ROPE_TYPE_NORMAL:
5660+
rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
5661+
break;
5662+
case GGML_ROPE_TYPE_NEOX:
5663+
case GGML_ROPE_TYPE_MROPE: //pure, not vision
5664+
rotate_pairs(n_dims, n_dims/2, cache, src, dst_data);
5665+
break;
5666+
case GGML_ROPE_TYPE_VISION:
5667+
rotate_pairs(ne0, n_dims, cache, src, dst_data);
5668+
break;
5669+
default:
5670+
//rope type not supported, silently default to NORMAL
5671+
rotate_pairs(n_dims, 1, cache, src, dst_data, 1);
56865672
}
56875673

5688-
if (is_vision) {
5689-
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5690-
const int64_t ic = i0/2;
5691-
5692-
const float cos_theta = cache[i0 + 0];
5693-
const float sin_theta = cache[i0 + 1];
5694-
5695-
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
5696-
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
5697-
5698-
const float x0 = src[0];
5699-
const float x1 = src[n_dims];
5700-
5701-
dst_data[0] = x0*cos_theta - x1*sin_theta;
5702-
dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
5703-
}
5704-
} else {
5674+
if (!is_vision) {
57055675
// fill the remain channels with data from src tensor
57065676
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
57075677
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -5711,7 +5681,7 @@ static void ggml_compute_forward_rope_f32(
57115681
dst_data[1] = src[1];
57125682
}
57135683
}
5714-
}
5684+
} //attn-heads
57155685
}
57165686
}
57175687
}

tests/test-rope.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
164164
((int32_t *) p2->data)[i] = n_past_2 + i;
165165
}
166166
// test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
167+
// note: GLM is not implemented, it will default to standard
167168
mode = m == 0 ? 0 : m == 1 ? 2 : 4;
168169

169170
// 100, 101, 102, ..., 172

0 commit comments

Comments
 (0)