Skip to content

Commit 6a0191a

Browse files
committed
metal : add imrope support
1 parent 0518b0a commit 6a0191a

File tree

4 files changed

+33
-10
lines changed

4 files changed

+33
-10
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
13321332

13331333
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
13341334
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
1335+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
13351336
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
13361337

13371338
if (is_neox) {
13381339
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1339-
} else if (is_mrope && !is_vision) {
1340+
} else if ((is_mrope || is_imrope) && !is_vision) {
13401341
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
13411342
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
13421343
} else if (is_vision) {
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
13461347
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
13471348
}
13481349

1349-
snprintf(name, 256, "%s", base);
1350+
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
13501351

13511352
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
13521353
if (res) {
13531354
return res;
13541355
}
13551356

1356-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1357+
ggml_metal_cv_t cv = ggml_metal_cv_init();
1358+
1359+
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1360+
1361+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1362+
1363+
ggml_metal_cv_free(cv);
13571364

13581365
return res;
13591366
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
7777
#define FC_MUL_MV 600
7878
#define FC_MUL_MM 700
79+
#define FC_ROPE 800
7980

8081
// op-specific constants
8182
#define OP_FLASH_ATTN_EXT_NQPTG 8

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
37093709
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
37103710
#endif
37113711

3712+
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
3713+
37123714
static float rope_yarn_ramp(const float low, const float high, const int i0) {
37133715
const float y = (i0 / 2 - low) / max(0.001f, high - low);
37143716
return 1.0f - min(1.0f, max(0.0f, y));
@@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi(
38893891
const int sector = ic % sect_dims;
38903892

38913893
float theta_base;
3892-
if (sector < args.sect_0) {
3893-
theta_base = (float) pos[i2];
3894-
} else if (sector < sec_w01) {
3895-
theta_base = (float) pos[i2 + args.ne02];
3896-
} else if (sector < sec_w012) {
3897-
theta_base = (float) pos[i2 + args.ne02 * 2];
3894+
if (FC_rope_is_imrope) {
3895+
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
3896+
theta_base = (float) pos[i2 + args.ne02 * 1];
3897+
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
3898+
theta_base = (float) pos[i2 + args.ne02 * 2];
3899+
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
3900+
theta_base = (float) pos[i2 + args.ne02 * 0];
3901+
} else { // e
3902+
theta_base = (float) pos[i2 + args.ne02 * 3];
3903+
}
38983904
} else {
3899-
theta_base = (float) pos[i2 + args.ne02 * 3];
3905+
if (sector < args.sect_0) {
3906+
theta_base = (float) pos[i2];
3907+
} else if (sector < sec_w01) {
3908+
theta_base = (float) pos[i2 + args.ne02 * 1];
3909+
} else if (sector < sec_w012) {
3910+
theta_base = (float) pos[i2 + args.ne02 * 2];
3911+
} else {
3912+
theta_base = (float) pos[i2 + args.ne02 * 3];
3913+
}
39003914
}
39013915
// end of mrope
39023916

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7081,6 +7081,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
70817081
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
70827082
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
70837083
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
7084+
test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
70847085
}
70857086

70867087
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)

0 commit comments

Comments
 (0)