Skip to content

Commit 8c189d3

Browse files
committed
rope complete
1 parent 5f83354 commit 8c189d3

File tree

3 files changed

+69
-38
lines changed

3 files changed

+69
-38
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -742,18 +742,21 @@ static void ggml_webgpu_rope(webgpu_context & ctx,
742742
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
743743
const int has_freq_factor = (src2 != nullptr);
744744

745-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
746745
const int n_dims = ((int32_t *) dst->op_params)[1];
747746
const int mode = ((int32_t *) dst->op_params)[2];
748747
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
749748

749+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
750750
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
751751
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
752752
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
753753
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
754754
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
755755
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
756756

757+
int sections[4];
758+
memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
759+
757760
float theta_scale = powf(freq_base, -2.0f / n_dims);
758761

759762
float corr_dims[2];
@@ -780,7 +783,11 @@ static void ggml_webgpu_rope(webgpu_context & ctx,
780783
*(uint32_t *) &freq_scale,
781784
*(uint32_t *) &ext_factor,
782785
*(uint32_t *) &corr_dims[0],
783-
*(uint32_t *) &corr_dims[1]
786+
*(uint32_t *) &corr_dims[1],
787+
(uint32_t) sections[0],
788+
(uint32_t) sections[1],
789+
(uint32_t) sections[2],
790+
(uint32_t) sections[3]
784791
};
785792

786793
std::vector<wgpu::BindGroupEntry> entries = {
@@ -1461,22 +1468,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
14611468
supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
14621469
break;
14631470
case GGML_OP_ROPE:
1464-
{
1465-
//std::cout << "ROPE op types: dst: " << ggml_type_name(op->type)
1466-
// << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
1467-
// << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")
1468-
// << ", src2: " << (op->src[2] ? ggml_type_name(op->src[2]->type) : "null") << std::endl;
1469-
//std::cout << "ROPE op shapes: dst: op->ne[0]=" << op->ne[0] << ", ne[1]=" << op->ne[1] << ", ne[2]=" << op->ne[2]
1470-
// << ", ne[3]=" << op->ne[3] << std::endl;
1471-
//std::cout << "ROPE op shapes: src0: src0->ne[0]=" << op->src[0]->ne[0] << ", ne[1]=" << op->src[0]->ne[1]
1472-
// << ", ne[2]=" << op->src[0]->ne[2] << ", ne[3]=" << op->src[0]->ne[3] << std::endl;
1473-
//std::cout << "ROPE op shapes: src1: src1->ne[0]=" << op->src[1]->ne[0] << ", ne[1]=" << op->src[1]->ne[1]
1474-
// << ", ne[2]=" << op->src[1]->ne[2] << ", ne[3]=" << op->src[1]->ne[3] << std::endl;
1475-
1476-
const int mode = ((int32_t *) op->op_params)[2];
1477-
supports_op = (mode == 0 || mode == 2) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
1478-
break;
1479-
}
1471+
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
1472+
break;
14801473
default:
14811474
break;
14821475
}

ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,13 @@ struct Params {
159159
freq_scale: f32,
160160
ext_factor: f32,
161161
corr_dim0: f32,
162-
corr_dim1: f32
162+
corr_dim1: f32,
163+
sections0: u32,
164+
sections1: u32,
165+
sections2: u32,
166+
sections3: u32
163167
};
164168

165-
166169
@group(0) @binding(0)
167170
var<storage, read_write> src0: array<{{TYPE}}>;
168171

@@ -189,19 +192,21 @@ fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
189192
return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
190193
}
191194

192-
fn pair_base(i0: u32) -> u32 {
193-
switch (params.mode) {
194-
case 0 { return i0; } // norm
195-
case 2 { return i0 / 2; } // neox
196-
default { return 1; }
195+
fn pair_base(i0: u32, div_2: bool) -> u32 {
196+
if (div_2) {
197+
return i0 / 2;
198+
} else {
199+
return i0;
197200
}
198201
}
199202

200-
fn pair_offset() -> u32 {
201-
switch (params.mode) {
202-
case 0 { return 1; } // norm
203-
case 2 { return params.n_dims / 2; } // neox
204-
default { return 1; }
203+
fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
204+
if (is_vision) {
205+
return params.n_dims;
206+
} else if (is_neox || is_mrope) {
207+
return params.n_dims / 2;
208+
} else {
209+
return 1;
205210
}
206211
}
207212

@@ -213,6 +218,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
213218
return;
214219
}
215220

221+
let is_neox = bool(params.mode & 2);
222+
let is_mrope = bool(params.mode & 8);
223+
let is_vision = params.mode == 24;
224+
216225
var i = gid.x * 2; // start index for this thread
217226
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
218227
i = i % (params.ne2 * params.ne1 * params.ne0);
@@ -224,20 +233,49 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
224233
let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
225234
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
226235

227-
if (i0 >= params.n_dims) {
228-
rotate(i_dst_row + i0, i_dst_row + i0 + 1, f32(src0[i_src_row + i0]), f32(src0[i_src_row + i0 + 1]));
236+
if (i0 >= params.n_dims && !is_vision) {
237+
let i_src = i_src_row + i0;
238+
let i_dst = i_dst_row + i0;
239+
rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
229240
return;
230241
}
231242

232-
let theta_base = f32(src1[params.offset_src1 + i2]) * pow(params.theta_scale, f32(i0)/2.0f);
243+
var theta_base_mult: u32 = 0;
244+
var theta_scale_pwr: u32 = i0 / 2;
245+
if (is_mrope) {
246+
let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
247+
let sec_w = params.sections1 + params.sections0;
248+
let sec_e = params.sections2 + sec_w;
249+
let sector = (i0 / 2) % sect_dims;
250+
if (sector >= params.sections0 && sector < sec_w) {
251+
theta_base_mult = 1;
252+
if (is_vision) {
253+
theta_scale_pwr = sector - params.sections0;
254+
}
255+
} else if (sector >= sec_w && sector < sec_e) {
256+
theta_base_mult = 2;
257+
if (is_vision) {
258+
theta_scale_pwr = sector - sec_w;
259+
}
260+
} else if (sector >= sec_e) {
261+
if (is_vision) {
262+
theta_scale_pwr = sector - sec_e;
263+
theta_scale_pwr = (i0 / 2) % sec_e;
264+
}
265+
theta_base_mult = 3;
266+
} else if (is_vision) {
267+
theta_scale_pwr = sector;
268+
}
269+
}
270+
let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
233271
let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
234272

235-
let i_src = i_src_row + pair_base(i0);
236-
let i_dst = i_dst_row + pair_base(i0);
273+
let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
274+
let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
237275

238276
let x0 = f32(src0[i_src]);
239-
let x1 = f32(src0[i_src + pair_offset()]);
240-
rotate(i_dst, i_dst + pair_offset(), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
277+
let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
278+
rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
241279
}
242280

243281
#end(SHADER)

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6488,7 +6488,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
64886488

64896489
// single inplace test per type/mode/ff
64906490
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
6491-
for (int mode : {0, 2}) {
6491+
for (int mode : {0, 2, 8, 24}) {
64926492
for (bool ff : {false, true}) {
64936493
test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true));
64946494
}

0 commit comments

Comments
 (0)