Commit f0fc822

Work on rope
1 parent d304f45 commit f0fc822

5 files changed: 288 additions & 12 deletions

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 118 additions & 4 deletions
@@ -137,6 +137,7 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline mul_ip_pipeline[2];
     wgpu::ComputePipeline rms_norm_pipeline;
     wgpu::ComputePipeline rms_norm_ip_pipeline;
+    wgpu::ComputePipeline rope_pipeline[2][2];
 
     size_t memset_bytes_per_thread;
 
@@ -693,9 +694,6 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx,
 static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool in_place = ggml_webgpu_tensor_equal(src, dst);
 
-    uint32_t eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
     std::vector<uint32_t> params = {
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
     };
@@ -714,7 +712,7 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
     params.push_back((uint32_t) src->ne[1]);
     params.push_back((uint32_t) src->ne[2]);
     params.push_back((uint32_t) src->ne[3]);
-    params.push_back(eps); // epsilon, will be bitcast to float in shader
+    params.push_back(*(uint32_t *) dst->op_params); // epsilon, treated as f32 in the shader
 
     std::vector<wgpu::BindGroupEntry> entries = {
         { .binding = 0,
@@ -740,6 +738,90 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t
     ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 
+static void ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) {
+    bool in_place = ggml_webgpu_tensor_equal(src0, dst);
+    int  has_freq_factor = (src2 != nullptr);
+
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    const int n_dims     = ((int32_t *) dst->op_params)[1];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5,  sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6,  sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7,  sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8,  sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9,  sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+    };
+    if (!in_place) {
+        params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)));
+    }
+    params.push_back((uint32_t) (src0->nb[1] / ggml_type_size(src0->type)));
+    params.push_back((uint32_t) (src0->nb[2] / ggml_type_size(src0->type)));
+    params.push_back((uint32_t) (src0->nb[3] / ggml_type_size(src0->type)));
+    if (!in_place) {
+        params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type)));
+        params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type)));
+        params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type)));
+    }
+    params.push_back((uint32_t) ggml_nelements(src0) / 2);
+    params.push_back((uint32_t) src0->ne[0]);
+    params.push_back((uint32_t) src0->ne[1]);
+    params.push_back((uint32_t) src0->ne[2]);
+
+    params.push_back((uint32_t) n_dims);
+    params.push_back(*(uint32_t *) &theta_scale);
+    params.push_back(*(uint32_t *) &attn_factor);
+    params.push_back(*(uint32_t *) &freq_scale);
+    params.push_back(*(uint32_t *) &ext_factor);
+    params.push_back(*(uint32_t *) &corr_dims[0]);
+    params.push_back(*(uint32_t *) &corr_dims[1]);
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+          .buffer  = ggml_webgpu_tensor_buf(src0),
+          .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+          .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
+        { .binding = 1,
+          .buffer  = ggml_webgpu_tensor_buf(src1),
+          .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+          .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
+    };
+    uint32_t dst_binding = 2;
+    if (has_freq_factor) {
+        dst_binding = 3;
+        entries.push_back({ .binding = 2,
+                            .buffer  = ggml_webgpu_tensor_buf(src2),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, src2),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, src2) });
+    }
+    if (!in_place) {
+        entries.push_back({ .binding = dst_binding,
+                            .buffer  = ggml_webgpu_tensor_buf(dst),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
+
+    wgpu::ComputePipeline pipeline;
+    if (in_place) {
+        pipeline = ctx->rope_pipeline[dst->type][has_freq_factor];
+    } else {
+        pipeline = ctx->rope_pipeline[dst->type][has_freq_factor];
+    }
+    size_t   max_wg_size = ctx->max_wg_size_x;
+    uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size;
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+}
+
 // Returns true if node has enqueued work into the queue, false otherwise
 static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
     if (ggml_is_empty(node)) {
@@ -749,6 +831,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
 
     ggml_tensor * src0 = node->src[0];
     ggml_tensor * src1 = node->src[1];
+    ggml_tensor * src2 = node->src[2];
 
     switch (node->op) {
         // no-ops
@@ -787,6 +870,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_RMS_NORM:
             ggml_webgpu_rms_norm(ctx, src0, node);
             break;
+        case GGML_OP_ROPE:
+            ggml_webgpu_rope(ctx, src0, src1, src2, node);
+            break;
         default:
             return false;
     }
@@ -1206,6 +1292,14 @@ static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
                                 "rms_norm_in_place", constants);
 }
 
+static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0], wgsl_rope_f32_norm, "rope_f32_norm", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1], wgsl_rope_f32_norm_ff, "rope_f32_norm_ff", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0], wgsl_rope_f16_norm, "rope_f16_norm", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1], wgsl_rope_f16_norm_ff, "rope_f16_norm_ff", constants);
+}
+
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
@@ -1287,6 +1381,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
 
     ggml_tensor * src0 = op->src[0];
     ggml_tensor * src1 = op->src[1];
+    ggml_tensor * src2 = op->src[2];
+
     // on smaller devices (or CI), tensors may be larger than the max storage buffer size
     if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
         (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
@@ -1360,6 +1456,23 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_RMS_NORM:
             supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
             break;
+        case GGML_OP_ROPE:
+            {
+                //std::cout << "ROPE op types: dst: " << ggml_type_name(op->type)
+                //    << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                //    << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null")
+                //    << ", src2: " << (op->src[2] ? ggml_type_name(op->src[2]->type) : "null") << std::endl;
+                //std::cout << "ROPE op shapes: dst: op->ne[0]=" << op->ne[0] << ", ne[1]=" << op->ne[1] << ", ne[2]=" << op->ne[2]
+                //    << ", ne[3]=" << op->ne[3] << std::endl;
+                //std::cout << "ROPE op shapes: src0: src0->ne[0]=" << op->src[0]->ne[0] << ", ne[1]=" << op->src[0]->ne[1]
+                //    << ", ne[2]=" << op->src[0]->ne[2] << ", ne[3]=" << op->src[0]->ne[3] << std::endl;
+                //std::cout << "ROPE op shapes: src1: src1->ne[0]=" << op->src[1]->ne[0] << ", ne[1]=" << op->src[1]->ne[1]
+                //    << ", ne[2]=" << op->src[1]->ne[2] << ", ne[3]=" << op->src[1]->ne[3] << std::endl;
+
+                const int mode = ((int32_t *) op->op_params)[2];
+                supports_op = mode == 0 && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+                break;
+            }
         default:
             break;
     }
@@ -1486,6 +1599,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_add_pipeline(ctx);
     ggml_webgpu_init_mul_pipeline(ctx);
     ggml_webgpu_init_rms_norm_pipeline(ctx);
+    ggml_webgpu_init_rope_pipeline(ctx);
 
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
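
A note on the parameter plumbing in ggml_webgpu_rope: ggml packs every rope hyperparameter into the 32-bit slots of dst->op_params (ints stored directly, floats bit-cast), and the WebGPU path forwards them to the shader as a flat vector of uint32_t words. The sketch below restates that unpacking in isolation; the struct and helper names are hypothetical, not part of ggml, but the slot indices are the ones the function above reads.

#include <cmath>
#include <cstdint>
#include <cstring>

// Illustrative only: gathers the rope hyperparameters that ggml stores
// in the 32-bit slots of a tensor's op_params array.
struct rope_hparams {
    int32_t n_dims;      // op_params[1]: number of rotated dimensions
    int32_t mode;        // op_params[2]: 0 = "norm", the only mode supported here
    int32_t n_ctx_orig;  // op_params[4]: original training context length
    float   freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
};

static rope_hparams unpack_rope_hparams(const int32_t * op_params) {
    rope_hparams p = {};
    p.n_dims     = op_params[1];
    p.mode       = op_params[2];
    p.n_ctx_orig = op_params[4];
    // the floats live bit-cast in slots 5..10, hence memcpy rather than a cast
    memcpy(&p.freq_base,   op_params + 5,  sizeof(float));
    memcpy(&p.freq_scale,  op_params + 6,  sizeof(float));
    memcpy(&p.ext_factor,  op_params + 7,  sizeof(float));
    memcpy(&p.attn_factor, op_params + 8,  sizeof(float));
    memcpy(&p.beta_fast,   op_params + 9,  sizeof(float));
    memcpy(&p.beta_slow,   op_params + 10, sizeof(float));
    return p;
}

// theta_scale = freq_base^(-2/n_dims) is the constant ratio between the
// rotation frequencies of successive dimension pairs: theta_i = pos * theta_scale^i.
static float rope_theta_scale(const rope_hparams & p) {
    return powf(p.freq_base, -2.0f / (float) p.n_dims);
}

Precomputing theta_scale on the host lets the shader step from one pair's angle to the next with a single multiply instead of a powf per element. The same bit-cast convention explains the *(uint32_t *) &theta_scale pushes above: the params vector is uploaded as raw 32-bit words, and the WGSL side declares those fields as f32.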

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 4 additions & 4 deletions
@@ -88,14 +88,14 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                     raise ValueError(f"DECLS key '{key}' not found.")
                 decls_code += decls_map[key] + "\n\n"
 
-            shader_variant = replace_placeholders(shader_template, variant["REPLS"])
-            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant)
+            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
+            final_shader = replace_placeholders(final_shader, variant["REPLS"])
             final_shader = expand_includes(final_shader, input_dir)
 
             if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
                 output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
-            elif "TYPE_SUFFIX" in variant["REPLS"]:
-                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"]
+            elif "SHADER_SUFFIX" in variant:
+                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
             elif "TYPE" in variant["REPLS"]:
                 output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
             else:
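
The reordering at lines 91-92 is subtle but meaningful: previously the REPLS placeholders were substituted into the template first and the DECLS block spliced in afterwards, so a placeholder inside a DECLS body was never substituted. Splicing DECLS first and then running replace_placeholders over the combined text closes that gap. A minimal sketch of the difference, assuming {{KEY}}-style placeholders (the helper below imitates, rather than reuses, the script's replace_placeholders):

import re

def replace_placeholders(text, repls):
    # stand-in for the script's helper: replaces {{KEY}} with its value
    for key, val in repls.items():
        text = text.replace("{{" + key + "}}", str(val))
    return text

template = "DECLS\nfn get_rows(i: u32) -> {{TYPE}} { /* ... */ }"
decls    = "alias block_t = {{TYPE}};"  # a placeholder hiding inside a decl
repls    = {"TYPE": "vec4<f32>"}

# old order: REPLS first, DECLS second -- the decl's {{TYPE}} is never touched
old = re.sub(r"\bDECLS\b", decls, replace_placeholders(template, repls))

# new order: DECLS first, REPLS second -- the decl is substituted too
new = replace_placeholders(re.sub(r"\bDECLS\b", decls, template), repls)

assert "{{TYPE}}" in old and "{{TYPE}}" not in new

The companion change from REPLS["TYPE_SUFFIX"] to a top-level "SHADER_SUFFIX" key, which get_rows.tmpl.wgsl adopts below, keeps output-naming metadata out of the substitution map so it can no longer be confused with a placeholder value.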

ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl

Lines changed: 1 addition & 1 deletion
@@ -2,9 +2,9 @@
 
 [
   {
+    "SHADER_SUFFIX": "f32_vec",
     "REPLS": {
       "TYPE" : "vec4<f32>",
-      "TYPE_SUFFIX": "f32_vec",
       "DST_TYPE": "vec4<f32>",
       "BLOCK_SIZE": 4
     },

ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl

Lines changed: 2 additions & 3 deletions
@@ -23,7 +23,7 @@ struct Params {
     ne2: u32,
     ne3: u32,
 
-    eps: u32
+    eps: f32
 };
 
 @group(0) @binding(2)
@@ -49,8 +49,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     for (var j: u32 = 0; j < params.ne0; j++) {
         sum += src[i_src_row + j] * src[i_src_row + j];
     }
-    let eps = bitcast<f32>(params.eps);
-    let scale = 1.0/sqrt(sum/f32(params.ne0) + eps);
+    let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
     for (var j: u32 = 0; j < params.ne0; j++) {
         dst[i_dst_row + j] = scale * src[i_src_row + j];
     }
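
Both halves of this change rely on the same fact: the params buffer reaches the GPU as raw 32-bit words, so the WGSL struct member type, not the host-side declaration, decides how each word is read. Declaring eps as f32 makes the shader interpret the bits the host pushes (params.push_back(*(uint32_t *) dst->op_params) in ggml-webgpu.cpp above) directly as a float, which removes the explicit bitcast<f32> at the use site.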
