Skip to content

Commit 89f6cef

Browse files
committed
Work on rope variants
1 parent 0dd415f commit 89f6cef

File tree

3 files changed

+195
-63
lines changed

3 files changed

+195
-63
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ struct webgpu_context_struct {
137137
wgpu::ComputePipeline mul_ip_pipeline[2];
138138
wgpu::ComputePipeline rms_norm_pipeline;
139139
wgpu::ComputePipeline rms_norm_ip_pipeline;
140-
wgpu::ComputePipeline rope_pipeline[2][2];
140+
wgpu::ComputePipeline rope_pipeline[2][2][2][2]; // type, mode, ff, inplace
141141

142142
size_t memset_bytes_per_thread;
143143

@@ -734,11 +734,17 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
734734
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
735735
}
736736

737-
static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) {
738-
bool in_place = ggml_webgpu_tensor_equal(src0, dst);
739-
int has_freq_factor = (src2 != nullptr);
740-
741-
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
737+
static void ggml_webgpu_rope(webgpu_context & ctx,
738+
ggml_tensor * src0,
739+
ggml_tensor * src1,
740+
ggml_tensor * src2,
741+
ggml_tensor * dst) {
742+
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
743+
const int has_freq_factor = (src2 != nullptr);
744+
const int mode = ((int32_t *) dst->op_params)[2];
745+
const int is_neox = mode & GGML_ROPE_TYPE_NEOX;
746+
747+
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
742748
const int n_dims = ((int32_t *) dst->op_params)[1];
743749
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
744750

@@ -757,30 +763,25 @@ static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tens
757763
std::vector<uint32_t> params = {
758764
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
759765
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
766+
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
767+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
768+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
769+
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
770+
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
771+
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
772+
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
773+
(uint32_t) ggml_nelements(src0) / 2,
774+
(uint32_t) src0->ne[0],
775+
(uint32_t) src0->ne[1],
776+
(uint32_t) src0->ne[2],
777+
(uint32_t) n_dims,
778+
*(uint32_t *) &theta_scale,
779+
*(uint32_t *) &attn_factor,
780+
*(uint32_t *) &freq_scale,
781+
*(uint32_t *) &ext_factor,
782+
*(uint32_t *) &corr_dims[0],
783+
*(uint32_t *) &corr_dims[1]
760784
};
761-
if (!in_place) {
762-
params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
763-
}
764-
params.push_back((uint32_t) (src0->nb[1] / ggml_type_size(src0->type)));
765-
params.push_back((uint32_t) (src0->nb[2] / ggml_type_size(src0->type)));
766-
params.push_back((uint32_t) (src0->nb[3] / ggml_type_size(src0->type)));
767-
if (!in_place) {
768-
params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
769-
params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
770-
params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
771-
}
772-
params.push_back((uint32_t) ggml_nelements(src0) / 2);
773-
params.push_back((uint32_t) src0->ne[0]);
774-
params.push_back((uint32_t) src0->ne[1]);
775-
params.push_back((uint32_t) src0->ne[2]);
776-
777-
params.push_back((uint32_t) n_dims);
778-
params.push_back(*(uint32_t *) &theta_scale);
779-
params.push_back(*(uint32_t *) &attn_factor);
780-
params.push_back(*(uint32_t *) &freq_scale);
781-
params.push_back(*(uint32_t *) &ext_factor);
782-
params.push_back(*(uint32_t *) &corr_dims[0]);
783-
params.push_back(*(uint32_t *) &corr_dims[1]);
784785

785786
std::vector<wgpu::BindGroupEntry> entries = {
786787
{ .binding = 0,
@@ -800,21 +801,16 @@ static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tens
800801
.offset = ggml_webgpu_tensor_align_offset(ctx, src2),
801802
.size = ggml_webgpu_tensor_binding_size(ctx, src2) });
802803
}
803-
if (!in_place) {
804+
if (!inplace) {
804805
entries.push_back({ .binding = dst_binding,
805806
.buffer = ggml_webgpu_tensor_buf(dst),
806807
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
807808
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
808809
}
809810

810-
wgpu::ComputePipeline pipeline;
811-
if (in_place) {
812-
pipeline = ctx->rope_pipeline[dst->type][has_freq_factor];
813-
} else {
814-
pipeline = ctx->rope_pipeline[dst->type][has_freq_factor];
815-
}
816-
size_t max_wg_size = ctx->max_wg_size_x;
817-
uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
811+
wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][is_neox][has_freq_factor][inplace];
812+
size_t max_wg_size = ctx->max_wg_size_x;
813+
uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
818814
ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
819815
}
820816

@@ -1290,10 +1286,22 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
12901286

12911287
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
12921288
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
1293-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0], wgsl_rope_f32_norm, "rope_f32_norm", constants);
1294-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1], wgsl_rope_f32_norm_ff, "rope_f32_norm_ff", constants);
1295-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0], wgsl_rope_f16_norm, "rope_f16_norm", constants);
1296-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1], wgsl_rope_f16_norm_ff, "rope_f16_norm_ff", constants);
1289+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0][0],
1290+
wgsl_rope_f32_norm, "rope_f32_norm", constants);
1291+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0][1],
1292+
wgsl_rope_f32_norm_inplace, "rope_f32_norm_inplace", constants);
1293+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1][0],
1294+
wgsl_rope_f32_norm_ff, "rope_f32_norm_ff", constants);
1295+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1][1],
1296+
wgsl_rope_f32_norm_ff_inplace, "rope_f32_norm_ff_inplace", constants);
1297+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0][0],
1298+
wgsl_rope_f16_norm, "rope_f16_norm", constants);
1299+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0][1],
1300+
wgsl_rope_f16_norm_inplace, "rope_f16_norm_inplace", constants);
1301+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1][0],
1302+
wgsl_rope_f16_norm_ff, "rope_f16_norm_ff", constants);
1303+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1][1],
1304+
wgsl_rope_f16_norm_ff_inplace, "rope_f16_norm_ff_inplace", constants);
12971305
}
12981306

12991307
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {

ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl

Lines changed: 118 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,54 +6,136 @@
66
"REPLS": {
77
"TYPE" : "f32",
88
},
9-
"DECLS": ["NO_FREQ_FAC"]
9+
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "NORM", "ROTATE"]
10+
},
11+
{
12+
"SHADER_SUFFIX": "f32_norm_inplace",
13+
"REPLS": {
14+
"TYPE" : "f32",
15+
},
16+
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "NORM", "ROTATE_INPLACE"]
1017
},
1118
{
1219
"SHADER_SUFFIX": "f16_norm",
1320
"REPLS": {
1421
"TYPE" : "f16",
1522
},
16-
"DECLS": ["NO_FREQ_FAC"]
23+
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "NORM", "ROTATE"]
24+
},
25+
{
26+
"SHADER_SUFFIX": "f16_norm_inplace",
27+
"REPLS": {
28+
"TYPE" : "f16",
29+
},
30+
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "NORM", "ROTATE_INPLACE"]
1731
},
1832
{
1933
"SHADER_SUFFIX": "f32_norm_ff",
2034
"REPLS": {
2135
"TYPE" : "f32",
2236
},
23-
"DECLS": ["FREQ_FAC"]
37+
"DECLS": ["FF_BINDINGS", "FF_FUNC", "NORM", "ROTATE"]
38+
},
39+
{
40+
"SHADER_SUFFIX": "f32_norm_ff_inplace",
41+
"REPLS": {
42+
"TYPE" : "f32",
43+
},
44+
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "NORM", "ROTATE_INPLACE"]
2445
},
2546
{
2647
"SHADER_SUFFIX": "f16_norm_ff",
2748
"REPLS": {
2849
"TYPE" : "f16",
2950
},
30-
"DECLS": ["FREQ_FAC"]
51+
"DECLS": ["FF_BINDINGS", "FF_FUNC", "NORM", "ROTATE"]
52+
},
53+
{
54+
"SHADER_SUFFIX": "f16_norm_ff_inplace",
55+
"REPLS": {
56+
"TYPE" : "f16",
57+
},
58+
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "NORM", "ROTATE_INPLACE"]
59+
},
60+
61+
{
62+
"SHADER_SUFFIX": "f32_neox",
63+
"REPLS": {
64+
"TYPE" : "f32",
65+
},
66+
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "NEOX", "ROTATE"]
67+
},
68+
{
69+
"SHADER_SUFFIX": "f16_neox",
70+
"REPLS": {
71+
"TYPE" : "f16",
72+
},
73+
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "NEOX", "ROTATE"]
74+
},
75+
{
76+
"SHADER_SUFFIX": "f32_neox_ff",
77+
"REPLS": {
78+
"TYPE" : "f32",
79+
},
80+
"DECLS": ["FF_BINDINGS", "FF_FUNC", "NEOX", "ROTATE"]
81+
},
82+
{
83+
"SHADER_SUFFIX": "f16_neox_ff",
84+
"REPLS": {
85+
"TYPE" : "f16",
86+
},
87+
"DECLS": ["FF_BINDINGS", "FF_FUNC", "NEOX", "ROTATE"]
3188
}
3289
]
3390

3491
#end(VARIANTS)
3592

3693
#define(DECLS)
3794

38-
#decl(NO_FREQ_FAC)
95+
#decl(ROTATE)
96+
fn rotate(i_dst: u32, out0: f32, out1: f32) {
97+
dst[i_dst] = {{TYPE}}(out0);
98+
dst[i_dst + pair_offset()] = {{TYPE}}(out1);
99+
}
100+
#enddecl(ROTATE)
39101

102+
#decl(ROTATE_INPLACE)
103+
fn rotate(i_dst: u32, out0: f32, out1: f32) {
104+
src0[i_dst] = {{TYPE}}(out0);
105+
src0[i_dst + pair_offset()] = {{TYPE}}(out1);
106+
}
107+
#enddecl(ROTATE_INPLACE)
108+
109+
#decl(NO_FF_FUNC)
40110
fn freq_factor(i: u32) -> f32 {
41111
return 1.0f;
42112
}
113+
#enddecl(NO_FF_FUNC)
114+
115+
#decl(FF_FUNC)
116+
fn freq_factor(i: u32) -> f32 {
117+
return src2[i/2];
118+
}
119+
#enddecl(FF_FUNC)
120+
121+
#decl(NO_FF_BINDINGS)
43122

44123
@group(0) @binding(2)
45124
var<storage, read_write> dst: array<{{TYPE}}>;
46125

47126
@group(0) @binding(3)
48127
var<uniform> params: Params;
49128

50-
#enddecl(NO_FREQ_FAC)
129+
#enddecl(NO_FF_BINDINGS)
51130

52-
#decl(FREQ_FAC)
131+
#decl(NO_FF_BINDINGS_INPLACE)
53132

54-
fn freq_factor(i: u32) -> f32 {
55-
return src2[i/2];
56-
}
133+
@group(0) @binding(2)
134+
var<uniform> params: Params;
135+
136+
#enddecl(NO_FF_BINDINGS_INPLACE)
137+
138+
#decl(FF_BINDINGS)
57139

58140
@group(0) @binding(2)
59141
var<storage, read_write> src2: array<f32>;
@@ -64,7 +146,29 @@ var<storage, read_write> dst: array<{{TYPE}}>;
64146
@group(0) @binding(4)
65147
var<uniform> params: Params;
66148

67-
#enddecl(FREQ_FAC)
149+
#enddecl(FF_BINDINGS)
150+
151+
#decl(FF_BINDINGS_INPLACE)
152+
153+
@group(0) @binding(2)
154+
var<storage, read_write> src2: array<f32>;
155+
156+
@group(0) @binding(3)
157+
var<uniform> params: Params;
158+
159+
#enddecl(FF_BINDINGS_INPLACE)
160+
161+
#decl(NORM)
162+
fn pair_offset() -> u32 {
163+
return 1;
164+
}
165+
#enddecl(NORM)
166+
167+
#decl(NEOX)
168+
fn pair_offset() -> u32 {
169+
return params.n_dims / 2;
170+
}
171+
#enddecl(NEOX)
68172

69173
#end(DECLS)
70174

@@ -146,18 +250,16 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
146250
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
147251

148252
if (i0 >= params.n_dims) {
149-
dst[i_dst] = src0[i_src];
150-
dst[i_dst + 1] = src0[i_src + 1];
253+
rotate(i_dst, f32(src0[i_src]), f32(src0[i_src + 1]));
151254
return;
152255
}
153256

154257
let theta_base = f32(src1[params.offset_src1 + i2]) * pow(params.theta_scale, f32(i0)/2.0f);
155258
let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
156259

157260
let x0 = f32(src0[i_src]);
158-
let x1 = f32(src0[i_src + 1]);
159-
dst[i_dst] = {{TYPE}}(x0 * thetas.x - x1 * thetas.y);
160-
dst[i_dst + 1] = {{TYPE}}(x0 * thetas.y + x1 * thetas.x);
261+
let x1 = f32(src0[i_src + pair_offset()]);
262+
rotate(i_dst, x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
161263
}
162264

163265
#end(SHADER)

0 commit comments

Comments
 (0)