Skip to content

Commit d192d9f

Browse files
authored
Merge pull request #35 from JJJYmmm/add_qwen3vl
add imrope metal support + add imrope support for sycl
2 parents 8b1d615 + 10ce7fb commit d192d9f

File tree

5 files changed

+63
-27
lines changed

5 files changed

+63
-27
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,11 +1332,12 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
13321332

13331333
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
13341334
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
1335+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
13351336
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
13361337

13371338
if (is_neox) {
13381339
snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1339-
} else if (is_mrope && !is_vision) {
1340+
} else if ((is_mrope || is_imrope) && !is_vision) {
13401341
GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
13411342
snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
13421343
} else if (is_vision) {
@@ -1346,14 +1347,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
13461347
snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
13471348
}
13481349

1349-
snprintf(name, 256, "%s", base);
1350+
snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
13501351

13511352
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
13521353
if (res) {
13531354
return res;
13541355
}
13551356

1356-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1357+
ggml_metal_cv_t cv = ggml_metal_cv_init();
1358+
1359+
ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1360+
1361+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1362+
1363+
ggml_metal_cv_free(cv);
13571364

13581365
return res;
13591366
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
7777
#define FC_MUL_MV 600
7878
#define FC_MUL_MM 700
79+
#define FC_ROPE 800
7980

8081
// op-specific constants
8182
#define OP_FLASH_ATTN_EXT_NQPTG 8

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3709,6 +3709,8 @@ template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_
37093709
template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
37103710
#endif
37113711

3712+
constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
3713+
37123714
static float rope_yarn_ramp(const float low, const float high, const int i0) {
37133715
const float y = (i0 / 2 - low) / max(0.001f, high - low);
37143716
return 1.0f - min(1.0f, max(0.0f, y));
@@ -3889,14 +3891,26 @@ kernel void kernel_rope_multi(
38893891
const int sector = ic % sect_dims;
38903892

38913893
float theta_base;
3892-
if (sector < args.sect_0) {
3893-
theta_base = (float) pos[i2];
3894-
} else if (sector < sec_w01) {
3895-
theta_base = (float) pos[i2 + args.ne02];
3896-
} else if (sector < sec_w012) {
3897-
theta_base = (float) pos[i2 + args.ne02 * 2];
3894+
if (FC_rope_is_imrope) {
3895+
if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
3896+
theta_base = (float) pos[i2 + args.ne02 * 1];
3897+
} else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
3898+
theta_base = (float) pos[i2 + args.ne02 * 2];
3899+
} else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
3900+
theta_base = (float) pos[i2 + args.ne02 * 0];
3901+
} else { // e
3902+
theta_base = (float) pos[i2 + args.ne02 * 3];
3903+
}
38983904
} else {
3899-
theta_base = (float) pos[i2 + args.ne02 * 3];
3905+
if (sector < args.sect_0) {
3906+
theta_base = (float) pos[i2];
3907+
} else if (sector < sec_w01) {
3908+
theta_base = (float) pos[i2 + args.ne02 * 1];
3909+
} else if (sector < sec_w012) {
3910+
theta_base = (float) pos[i2 + args.ne02 * 2];
3911+
} else {
3912+
theta_base = (float) pos[i2 + args.ne02 * 3];
3913+
}
39003914
}
39013915
// end of mrope
39023916

ggml/src/ggml-sycl/rope.cpp

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
119119
const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120120
const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121121
const float theta_scale, const float * freq_factors, const mrope_sections sections,
122-
const sycl::nd_item<3> & item_ct1) {
122+
const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
123123
// get index pos
124124
const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
125125
if (i0 >= ne0) {
@@ -143,17 +143,29 @@ static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const
143143

144144

145145
float theta_base = 0.0;
146-
if (sector < sections.v[0]) {
147-
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
148-
}
149-
else if (sector >= sections.v[0] && sector < sec_w) {
150-
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
151-
}
152-
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
153-
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
154-
}
155-
else if (sector >= sec_w + sections.v[2]) {
156-
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
146+
if (is_imrope) {
147+
if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
148+
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
149+
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
150+
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
151+
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
152+
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
153+
} else {
154+
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
155+
}
156+
} else {
157+
if (sector < sections.v[0]) {
158+
theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
159+
}
160+
else if (sector >= sections.v[0] && sector < sec_w) {
161+
theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
162+
}
163+
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
164+
theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
165+
}
166+
else if (sector >= sec_w + sections.v[2]) {
167+
theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
168+
}
157169
}
158170

159171
const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
@@ -281,7 +293,7 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
281293
const size_t s2, const int n_dims, const int nr, const int32_t * pos,
282294
const float freq_scale, const float freq_base, const float ext_factor,
283295
const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
284-
const mrope_sections sections, queue_ptr stream) {
296+
const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
285297
GGML_ASSERT(ne0 % 2 == 0);
286298
const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
287299
const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
@@ -297,12 +309,12 @@ static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1,
297309
if (freq_factors == nullptr) {
298310
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
299311
rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
300-
corr_dims, theta_scale, freq_factors, sections, item_ct1);
312+
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
301313
});
302314
} else {
303315
stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
304316
rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
305-
corr_dims, theta_scale, freq_factors, sections, item_ct1);
317+
corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
306318
});
307319
}
308320
}
@@ -381,6 +393,7 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
381393

382394
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
383395
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
396+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
384397
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
385398

386399
if (is_mrope) {
@@ -422,11 +435,11 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
422435
if (dst->src[0]->type == GGML_TYPE_F16) {
423436
rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
424437
s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
425-
freq_factors, sections, main_stream);
438+
freq_factors, sections, is_imrope, main_stream);
426439
} else if (dst->src[0]->type == GGML_TYPE_F32) {
427440
rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
428441
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
429-
main_stream);
442+
is_imrope, main_stream);
430443
} else {
431444
GGML_ABORT("Fatal error: Tensor type unsupported!");
432445
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7081,6 +7081,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
70817081
test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
70827082
test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
70837083
test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
7084+
test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
70847085
}
70857086

70867087
test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)

0 commit comments

Comments
 (0)