
Commit eee4e4d

Update cpy shader to handle cont/more types
1 parent a7c9d33 commit eee4e4d

4 files changed: 123 additions & 66 deletions

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 20 additions & 6 deletions
@@ -130,7 +130,7 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline get_rows_pipeline[30];
     wgpu::ComputePipeline get_rows_f32_no_vec_pipeline;
-    wgpu::ComputePipeline cpy_pipeline;
+    wgpu::ComputePipeline cpy_pipeline[2][2];  // src type, dst type
     wgpu::ComputePipeline add_pipeline[2][2];  // type, inplace
     wgpu::ComputePipeline sub_pipeline[2][2];  // type, inplace
     wgpu::ComputePipeline mul_pipeline[2][2];  // type, inplace
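The new cpy_pipeline table is indexed directly with ggml type ids (cpy_pipeline[src->type][dst->type] in the dispatch below). That works because GGML_TYPE_F32 and GGML_TYPE_F16 are the first two values of enum ggml_type in ggml.h (0 and 1). A minimal standalone sketch of the lookup, with a stand-in Pipeline type in place of wgpu::ComputePipeline:

#include <cassert>

// Mirrors the first two entries of enum ggml_type in ggml.h.
enum ggml_type_subset { GGML_TYPE_F32 = 0, GGML_TYPE_F16 = 1 };

// Stand-in for the wgpu::ComputePipeline handles held by the backend.
struct Pipeline { const char * shader_name; };

// A 2x2 table indexed directly by type id, as in cpy_pipeline[src->type][dst->type].
static const Pipeline cpy_pipeline[2][2] = {
    { { "cpy_f32_f32" }, { "cpy_f32_f16" } },  // src f32
    { { "cpy_f16_f32" }, { "cpy_f16_f16" } },  // src f16
};

int main() {
    assert(cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32].shader_name != nullptr);
    return 0;
}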
@@ -491,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
         (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
         (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
-        // Logical shape — same for both tensors even if permuted
-        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3]
+        // Logical shapes
+        (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
     };
 
     std::vector<wgpu::BindGroupEntry> entries = {
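Only three extents per tensor are uploaded: params already carries the total element count ne, so the outermost extent is implied by the shader's leading division and never has to be stored. An illustrative helper (not code from this commit) showing why:

#include <cstdint>

// The shader computes i3 = gid.x / (ne2 * ne1 * ne0); with gid.x bounded by
// params.ne, the outermost extent falls out of the division and would be
// recoverable as:
uint32_t implied_ne3(uint32_t ne_total, uint32_t ne0, uint32_t ne1, uint32_t ne2) {
    return ne_total / (ne0 * ne1 * ne2);
}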
@@ -508,7 +509,8 @@
 
     size_t   max_wg_size = ctx->max_wg_size_x;
     uint32_t wg_x        = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
 }
 
 static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
@@ -930,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_RESHAPE:
             return false;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
             ggml_webgpu_cpy(ctx, src0, node);
             break;
         case GGML_OP_SET_ROWS:
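GGML_OP_CONT can reuse ggml_webgpu_cpy because materializing a non-contiguous view is the same element-wise copy as CPY; only the stride tables differ. A toy standalone example (plain C++, not backend code) that makes a transposed 2x3 matrix contiguous:

#include <cstdio>

int main() {
    const float src[6] = { 0, 1, 2, 3, 4, 5 };  // 2x3, row-major
    float dst[6];                               // 3x2, row-major (contiguous)
    const unsigned ssrc[2] = { 1, 3 };          // src strides: step per (col, row)
    const unsigned sdst[2] = { 2, 1 };          // dst strides for the same loop order
    for (unsigned r = 0; r < 2; ++r) {
        for (unsigned c = 0; c < 3; ++c) {
            dst[c * sdst[0] + r * sdst[1]] = src[c * ssrc[0] + r * ssrc[1]];
        }
    }
    for (float v : dst) { printf("%g ", v); }   // prints: 0 3 1 4 2 5
    return 0;
}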
@@ -1360,8 +1363,15 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy",
-                                ggml_webgpu_max_wg_size_entry(webgpu_ctx));
+    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
+                                wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16],
+                                wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
+                                wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
+                                wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
 static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
@@ -1608,6 +1618,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                           (src1->type == op->type);
             break;
         case GGML_OP_CPY:
+        case GGML_OP_CONT:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                          (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+            break;
         case GGML_OP_SET_ROWS:
             supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32);
             break;
ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "f32"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f32",
+      "DST_TYPE": "f16"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f16",
+      "DST_TYPE": "f16"
+    }
+  },
+  {
+    "REPLS": {
+      "SRC_TYPE": "f16",
+      "DST_TYPE": "f32"
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+enable f16;
+
+@group(0) @binding(0)
+var<storage, read_write> src: array<{{SRC_TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> dst: array<{{DST_TYPE}}>;
+
+struct Params {
+    ne: u32,         // total number of elements
+    offset_src: u32, // in elements
+    offset_dst: u32, // in elements
+
+    // Strides (in elements) — may be permuted
+    stride_src0: u32,
+    stride_src1: u32,
+    stride_src2: u32,
+    stride_src3: u32,
+
+    stride_dst0: u32,
+    stride_dst1: u32,
+    stride_dst2: u32,
+    stride_dst3: u32,
+
+    // Logical shapes
+    src_ne0: u32,
+    src_ne1: u32,
+    src_ne2: u32,
+
+    dst_ne0: u32,
+    dst_ne1: u32,
+    dst_ne2: u32
+};
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x >= params.ne) {
+        return;
+    }
+
+    var i = gid.x;
+    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
+    let i2 = i / (params.src_ne1 * params.src_ne0);
+    i = i % (params.src_ne1 * params.src_ne0);
+    let i1 = i / params.src_ne0;
+    let i0 = i % params.src_ne0;
+
+    var j = gid.x;
+    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
+    let j2 = j / (params.dst_ne1 * params.dst_ne0);
+    j = j % (params.dst_ne1 * params.dst_ne0);
+    let j1 = j / params.dst_ne0;
+    let j0 = j % params.dst_ne0;
+
+    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
+                  i2 * params.stride_src2 + i3 * params.stride_src3;
+
+    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
+                  j2 * params.stride_dst2 + j3 * params.stride_dst3;
+
+    dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
+}
+#end(SHADER)
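The kernel decomposes the same flat thread id twice: under the source shape to find where to read, and under the destination shape to find where to write, so a copy fused with a reshape still lands every element. A CPU emulation of one invocation, assuming the same parameter meanings as Params above (a sketch, not backend code):

#include <cassert>
#include <cstdint>

struct shape3 { uint32_t ne0, ne1, ne2; };

// Decompose a flat index under a logical shape, then apply strides,
// matching the i0..i3 / j0..j3 math in the shader.
static uint32_t strided_index(uint32_t flat, shape3 ne, const uint32_t stride[4]) {
    const uint32_t plane = ne.ne1 * ne.ne0;
    const uint32_t vol   = ne.ne2 * plane;
    const uint32_t i3 = flat / vol;   flat %= vol;
    const uint32_t i2 = flat / plane; flat %= plane;
    const uint32_t i1 = flat / ne.ne0;
    const uint32_t i0 = flat % ne.ne0;
    return i0 * stride[0] + i1 * stride[1] + i2 * stride[2] + i3 * stride[3];
}

int main() {
    // Copy a contiguous 4-element row into a 2x2 destination layout.
    const float src[4] = { 1, 2, 3, 4 };
    float dst[4] = { 0, 0, 0, 0 };
    const shape3   src_ne = { 4, 1, 1 };
    const shape3   dst_ne = { 2, 2, 1 };
    const uint32_t src_stride[4] = { 1, 4, 4, 4 };
    const uint32_t dst_stride[4] = { 1, 2, 4, 4 };
    for (uint32_t gid = 0; gid < 4; ++gid) {  // one iteration per shader thread
        dst[strided_index(gid, dst_ne, dst_stride)] = src[strided_index(gid, src_ne, src_stride)];
    }
    assert(dst[0] == 1 && dst[1] == 2 && dst[2] == 3 && dst[3] == 4);
    return 0;
}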

ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl

Lines changed: 0 additions & 60 deletions
This file was deleted.

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,8 @@ def generate_variants(fname, input_dir, output_dir, outfile):
             output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
         elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
             output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
+        elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
+            output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
         elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
             output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
         else:
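The new elif mirrors the existing SRC0_TYPE/SRC1_TYPE branch: a variant is named <base>_<SRC_TYPE>_<DST_TYPE>, and the embed step exposes it to C++ as a generated symbol, which is what the pipeline setup above references as wgsl_cpy_f32_f16 and friends. A hypothetical restatement of the naming rule (illustration only; variant_name is not a helper in the script):

#include <string>

// "<base>_<SRC_TYPE>_<DST_TYPE>", e.g. variant_name("cpy", "f32", "f16")
// yields "cpy_f32_f16".
std::string variant_name(const std::string & base,
                         const std::string & src_type,
                         const std::string & dst_type) {
    return base + "_" + src_type + "_" + dst_type;
}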
