Skip to content

Commit 0bed5d8

Browse files
committed
webgpu: add imrope w/o check
1 parent 19a458f commit 0bed5d8

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
221221

222222
let is_neox = bool(params.mode & 2);
223223
let is_mrope = bool(params.mode & 8);
224+
let is_imrope = params.mode == 40;
224225
let is_vision = params.mode == 24;
225226

226227
var i = gid.x * 2; // start index for this thread
@@ -248,24 +249,36 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
248249
let sec_w = params.sections1 + params.sections0;
249250
let sec_e = params.sections2 + sec_w;
250251
let sector = (i0 / 2) % sect_dims;
251-
if (sector >= params.sections0 && sector < sec_w) {
252-
theta_base_mult = 1;
253-
if (is_vision) {
254-
theta_scale_pwr = sector - params.sections0;
255-
}
256-
} else if (sector >= sec_w && sector < sec_e) {
257-
theta_base_mult = 2;
258-
if (is_vision) {
259-
theta_scale_pwr = sector - sec_w;
260-
}
261-
} else if (sector >= sec_e) {
262-
if (is_vision) {
263-
theta_scale_pwr = sector - sec_e;
264-
theta_scale_pwr = (i0 / 2) % sec_e;
265-
}
266-
theta_base_mult = 3;
267-
} else if (is_vision) {
268-
theta_scale_pwr = sector;
252+
if (is_imrope) {
253+
if (sector % 3 == 1 && sector < 3 * params.sections1) {
254+
theta_base_mult = 1;
255+
} else if (sector % 3 == 2 && sector < 3 * params.sections2) {
256+
theta_base_mult = 2;
257+
} else if (sector % 3 == 0 && sector < 3 * params.sections0) {
258+
theta_base_mult = 0;
259+
} else {
260+
theta_base_mult = 3;
261+
}
262+
} else {
263+
if (sector >= params.sections0 && sector < sec_w) {
264+
theta_base_mult = 1;
265+
if (is_vision) {
266+
theta_scale_pwr = sector - params.sections0;
267+
}
268+
} else if (sector >= sec_w && sector < sec_e) {
269+
theta_base_mult = 2;
270+
if (is_vision) {
271+
theta_scale_pwr = sector - sec_w;
272+
}
273+
} else if (sector >= sec_e) {
274+
if (is_vision) {
275+
theta_scale_pwr = sector - sec_e;
276+
theta_scale_pwr = (i0 / 2) % sec_e;
277+
}
278+
theta_base_mult = 3;
279+
} else if (is_vision) {
280+
theta_scale_pwr = sector;
281+
}
269282
}
270283
}
271284
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));

0 commit comments

Comments
 (0)