Skip to content

Commit c24f4e2

Browse files
foldl and ggerganov authored
ggml : update ggml_rope_multi (#12665)
* update `rope_multi`: 1. add `ggml_rope_multi_inplace`; 1. use `GGML_MROPE_SECTIONS` instead of 4. * Apply suggestions from code review Co-authored-by: Georgi Gerganov <[email protected]> --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent d8914fc commit c24f4e2

File tree

2 files changed

+63
-41
lines changed

2 files changed

+63
-41
lines changed

ggml/include/ggml.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,8 @@
241241
#define GGML_ROPE_TYPE_MROPE 8
242242
#define GGML_ROPE_TYPE_VISION 24
243243

244+
#define GGML_MROPE_SECTIONS 4
245+
244246
#define GGML_UNUSED(x) (void)(x)
245247

246248
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -1660,7 +1662,7 @@ extern "C" {
16601662
struct ggml_tensor * b,
16611663
struct ggml_tensor * c,
16621664
int n_dims,
1663-
int sections[4],
1665+
int sections[GGML_MROPE_SECTIONS],
16641666
int mode,
16651667
int n_ctx_orig,
16661668
float freq_base,
@@ -1686,6 +1688,22 @@ extern "C" {
16861688
float beta_fast,
16871689
float beta_slow);
16881690

1691+
GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
1692+
struct ggml_context * ctx,
1693+
struct ggml_tensor * a,
1694+
struct ggml_tensor * b,
1695+
struct ggml_tensor * c,
1696+
int n_dims,
1697+
int sections[GGML_MROPE_SECTIONS],
1698+
int mode,
1699+
int n_ctx_orig,
1700+
float freq_base,
1701+
float freq_scale,
1702+
float ext_factor,
1703+
float attn_factor,
1704+
float beta_fast,
1705+
float beta_slow);
1706+
16891707
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
16901708
struct ggml_context * ctx,
16911709
struct ggml_tensor * a,

ggml/src/ggml.c

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3885,6 +3885,7 @@ static struct ggml_tensor * ggml_rope_impl(
38853885
struct ggml_tensor * b,
38863886
struct ggml_tensor * c,
38873887
int n_dims,
3888+
int sections[GGML_MROPE_SECTIONS],
38883889
int mode,
38893890
int n_ctx_orig,
38903891
float freq_base,
@@ -3898,15 +3899,19 @@ static struct ggml_tensor * ggml_rope_impl(
38983899

38993900
GGML_ASSERT(ggml_is_vector(b));
39003901
GGML_ASSERT(b->type == GGML_TYPE_I32);
3901-
GGML_ASSERT(a->ne[2] == b->ne[0]);
3902+
3903+
bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
3904+
if (mrope_used) {
3905+
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
3906+
} else {
3907+
GGML_ASSERT(a->ne[2] == b->ne[0]);
3908+
}
39023909

39033910
if (c) {
39043911
GGML_ASSERT(c->type == GGML_TYPE_F32);
39053912
GGML_ASSERT(c->ne[0] >= n_dims / 2);
39063913
}
39073914

3908-
int sections[4] = {0, 0, 0, 0};
3909-
39103915
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
39113916

39123917
int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -3916,7 +3921,11 @@ static struct ggml_tensor * ggml_rope_impl(
39163921
memcpy(params + 8, &attn_factor, sizeof(float));
39173922
memcpy(params + 9, &beta_fast, sizeof(float));
39183923
memcpy(params + 10, &beta_slow, sizeof(float));
3919-
memcpy(params + 11, &sections, sizeof(int)*4);
3924+
if (mrope_used) {
3925+
memcpy(params + 11, sections, sizeof(int32_t) * GGML_MROPE_SECTIONS);
3926+
} else {
3927+
memset(params + 11, 0, sizeof(int32_t) * GGML_MROPE_SECTIONS);
3928+
}
39203929
ggml_set_op_params(result, params, sizeof(params));
39213930

39223931
result->op = GGML_OP_ROPE;
@@ -3934,7 +3943,7 @@ struct ggml_tensor * ggml_rope(
39343943
int n_dims,
39353944
int mode) {
39363945
return ggml_rope_impl(
3937-
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
3946+
ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, false
39383947
);
39393948
}
39403949

@@ -3944,7 +3953,7 @@ struct ggml_tensor * ggml_rope_multi(
39443953
struct ggml_tensor * b,
39453954
struct ggml_tensor * c,
39463955
int n_dims,
3947-
int sections[4],
3956+
int sections[GGML_MROPE_SECTIONS],
39483957
int mode,
39493958
int n_ctx_orig,
39503959
float freq_base,
@@ -3953,36 +3962,31 @@ struct ggml_tensor * ggml_rope_multi(
39533962
float attn_factor,
39543963
float beta_fast,
39553964
float beta_slow) {
3956-
// Multimodal Rotary Position Embedding
3957-
GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
3958-
3959-
GGML_ASSERT(ggml_is_vector(b));
3960-
GGML_ASSERT(b->type == GGML_TYPE_I32);
3961-
GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token
3962-
3963-
if (c) {
3964-
GGML_ASSERT(c->type == GGML_TYPE_F32);
3965-
GGML_ASSERT(c->ne[0] >= n_dims / 2);
3966-
}
3967-
3968-
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
3969-
3970-
int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
3971-
memcpy(params + 5, &freq_base, sizeof(float));
3972-
memcpy(params + 6, &freq_scale, sizeof(float));
3973-
memcpy(params + 7, &ext_factor, sizeof(float));
3974-
memcpy(params + 8, &attn_factor, sizeof(float));
3975-
memcpy(params + 9, &beta_fast, sizeof(float));
3976-
memcpy(params + 10, &beta_slow, sizeof(float));
3977-
memcpy(&params[11], sections, sizeof(int)*4);
3978-
ggml_set_op_params(result, params, sizeof(params));
3979-
3980-
result->op = GGML_OP_ROPE;
3981-
result->src[0] = a;
3982-
result->src[1] = b;
3983-
result->src[2] = c;
3965+
return ggml_rope_impl(
3966+
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
3967+
ext_factor, attn_factor, beta_fast, beta_slow, false
3968+
);
3969+
}
39843970

3985-
return result;
3971+
struct ggml_tensor * ggml_rope_multi_inplace(
3972+
struct ggml_context * ctx,
3973+
struct ggml_tensor * a,
3974+
struct ggml_tensor * b,
3975+
struct ggml_tensor * c,
3976+
int n_dims,
3977+
int sections[GGML_MROPE_SECTIONS],
3978+
int mode,
3979+
int n_ctx_orig,
3980+
float freq_base,
3981+
float freq_scale,
3982+
float ext_factor,
3983+
float attn_factor,
3984+
float beta_fast,
3985+
float beta_slow) {
3986+
return ggml_rope_impl(
3987+
ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale,
3988+
ext_factor, attn_factor, beta_fast, beta_slow, true
3989+
);
39863990
}
39873991

39883992
struct ggml_tensor * ggml_rope_inplace(
@@ -3992,7 +3996,7 @@ struct ggml_tensor * ggml_rope_inplace(
39923996
int n_dims,
39933997
int mode) {
39943998
return ggml_rope_impl(
3995-
ctx, a, b, NULL, n_dims, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
3999+
ctx, a, b, NULL, n_dims, NULL, mode, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, true
39964000
);
39974001
}
39984002

@@ -4011,7 +4015,7 @@ struct ggml_tensor * ggml_rope_ext(
40114015
float beta_fast,
40124016
float beta_slow) {
40134017
return ggml_rope_impl(
4014-
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4018+
ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
40154019
ext_factor, attn_factor, beta_fast, beta_slow, false
40164020
);
40174021
}
@@ -4031,7 +4035,7 @@ struct ggml_tensor * ggml_rope_ext_inplace(
40314035
float beta_fast,
40324036
float beta_slow) {
40334037
return ggml_rope_impl(
4034-
ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4038+
ctx, a, b, c, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
40354039
ext_factor, attn_factor, beta_fast, beta_slow, true
40364040
);
40374041
}
@@ -4050,7 +4054,7 @@ struct ggml_tensor * ggml_rope_custom(
40504054
float beta_fast,
40514055
float beta_slow) {
40524056
return ggml_rope_impl(
4053-
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4057+
ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
40544058
ext_factor, attn_factor, beta_fast, beta_slow, false
40554059
);
40564060
}
@@ -4069,7 +4073,7 @@ struct ggml_tensor * ggml_rope_custom_inplace(
40694073
float beta_fast,
40704074
float beta_slow) {
40714075
return ggml_rope_impl(
4072-
ctx, a, b, NULL, n_dims, mode, n_ctx_orig, freq_base, freq_scale,
4076+
ctx, a, b, NULL, n_dims, NULL, mode, n_ctx_orig, freq_base, freq_scale,
40734077
ext_factor, attn_factor, beta_fast, beta_slow, true
40744078
);
40754079
}

0 commit comments

Comments (0)