Skip to content

Commit cbca610

Browse files
committed
vulkan: add imrope w/o check
1 parent 10ce7fb commit cbca610

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,7 @@ struct vk_op_rope_push_constants {
10561056
uint32_t s1;
10571057
uint32_t s2;
10581058
int32_t sections[4];
1059+
uint32_t is_imrope;
10591060
uint32_t is_back;
10601061
uint32_t set_rows_stride;
10611062
};
@@ -9925,6 +9926,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
99259926
memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
99269927
}
99279928

9929+
const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
9930+
99289931
float corr_dims[2];
99299932
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
99309933

@@ -9946,7 +9949,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
99469949
(uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
99479950
freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
99489951
src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
9949-
{ sections[0], sections[1], sections[2], sections[3] }, backprop, set_rows_stride,
9952+
{ sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
99509953
}, dryrun);
99519954
}
99529955

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ layout (push_constant) uniform parameter {
2727
uint s1;
2828
uint s2;
2929
int sections[4];
30+
uint is_imrope;
3031
uint is_back;
3132
uint set_rows_stride;
3233
} p;

ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,29 @@ void main() {
3232
const uint sector = (i0 / 2) % sect_dims;
3333

3434
float theta_base = 0.0;
35-
if (sector < p.sections[0]) {
36-
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
37-
}
38-
else if (sector >= p.sections[0] && sector < sec_w) {
39-
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
40-
}
41-
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
42-
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
43-
}
44-
else if (sector >= sec_w + p.sections[2]) {
45-
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
35+
if (p.is_imrope) {
36+
if (sector % 3 == 1 && sector < 3 * p.sections[1]) {
37+
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
38+
} else if (sector % 3 == 2 && sector < 3 * p.sections[2]) {
39+
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
40+
} else if (sector % 3 == 0 && sector < 3 * p.sections[0]) {
41+
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
42+
} else {
43+
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
44+
}
45+
} else {
46+
if (sector < p.sections[0]) {
47+
theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f);
48+
}
49+
else if (sector >= p.sections[0] && sector < sec_w) {
50+
theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f);
51+
}
52+
else if (sector >= sec_w && sector < sec_w + p.sections[2]) {
53+
theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f);
54+
}
55+
else if (sector >= sec_w + p.sections[2]) {
56+
theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f);
57+
}
4658
}
4759

4860
const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;

0 commit comments

Comments
 (0)