
Commit 6355137

Work on templating for different types in shaders
1 parent 4ad0986 commit 6355137

4 files changed: 148 additions & 42 deletions


ggml/src/ggml-webgpu/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -20,8 +20,8 @@ add_custom_command(
     COMMAND ${CMAKE_COMMAND} -E make_directory ${SHADER_OUTPUT_DIR}
     COMMAND ${CMAKE_COMMAND} -E env PYTHONIOENCODING=utf-8
             ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
-            --input "${SHADER_DIR}"
-            --output "${SHADER_HEADER}"
+            --input_dir "${SHADER_DIR}"
+            --output_file "${SHADER_HEADER}"
     DEPENDS ${WGSL_SHADER_FILES} ${CMAKE_CURRENT_SOURCE_DIR}/wgsl-shaders/embed_wgsl.py
     VERBATIM
 )
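
For reference, a hypothetical manual invocation of the script with its renamed arguments, sketched in Python via subprocess; the input and output paths below are placeholders for the values CMake passes through SHADER_DIR and SHADER_HEADER, and "python3" stands in for ${Python3_EXECUTABLE}.

# Hypothetical manual run of embed_wgsl.py (placeholder paths, not the CMake values).
import subprocess

subprocess.run(
    [
        "python3", "ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py",
        "--input_dir", "ggml/src/ggml-webgpu/wgsl-shaders",       # was --input
        "--output_file", "build/generated/ggml-wgsl-shaders.h",   # was --output; placeholder path
    ],
    check=True,
)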

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 68 additions & 14 deletions
@@ -50,6 +50,13 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
 
 /* Struct definitions */
 
+struct webgpu_pipeline_info {
+    std::string name;
+    const char * shader_code;
+    ggml_type src0_type;
+    ggml_type src1_type;
+};
+
 // Forward reference
 static void ggml_webgpu_create_buffer(wgpu::Device & device,
                                       wgpu::Buffer & buffer,
@@ -124,7 +131,8 @@ struct webgpu_context_struct {
     webgpu_buf_pool set_rows_error_buf_pool;
 
     wgpu::ComputePipeline memset_pipeline;
-    wgpu::ComputePipeline mul_mat_pipeline;
+    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16]
+    wgpu::ComputePipeline mul_mat_pipeline[2][2];
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;
 
@@ -227,6 +235,15 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
 
 /** End WebGPU object initializations */
 
+/** Utility Functions */
+
+size_t ggml_webgpu_binding_size(ggml_tensor * t, size_t misalignment) {
+    return (ggml_nbytes(t) + misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
+           ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
+}
+
+/** End Utility Functions */
+
 /** WebGPU Actions */
 
 // Wait for the queue to finish processing all submitted work
@@ -479,13 +496,11 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
         { .binding = 0,
           .buffer = ggml_backend_webgpu_tensor_buf(src),
          .offset = src_offset,
-          .size = (ggml_nbytes(src) + src_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
-                  ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
+          .size = ggml_webgpu_binding_size(src, src_misalignment) },
         { .binding = 1,
           .buffer = ggml_backend_webgpu_tensor_buf(dst),
           .offset = dst_offset,
-          .size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) &
-                  ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
+          .size = ggml_webgpu_binding_size(dst, dst_misalignment) }
     };
 
     size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
@@ -542,15 +557,15 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
         { .binding = 0,
           .buffer = ggml_backend_webgpu_tensor_buf(src),
          .offset = ggml_backend_webgpu_tensor_offset(src),
-          .size = ggml_nbytes(src) },
+          .size = ggml_webgpu_binding_size(src, src_misalignment) },
         { .binding = 1,
          .buffer = ggml_backend_webgpu_tensor_buf(idx),
          .offset = ggml_backend_webgpu_tensor_offset(idx),
-          .size = ggml_nbytes(idx) },
+          .size = ggml_webgpu_binding_size(idx, idx_misalignment) },
         { .binding = 2,
          .buffer = ggml_backend_webgpu_tensor_buf(dst),
          .offset = ggml_backend_webgpu_tensor_offset(dst),
-          .size = ggml_nbytes(dst) },
+          .size = ggml_webgpu_binding_size(dst, dst_misalignment) },
         { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
     };
 
@@ -564,7 +579,21 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
 }
 
 static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+    size_t src0_offset = ggml_backend_webgpu_tensor_offset(src0);
+    size_t src0_misalignment = src0_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    // align to minimum offset alignment
+    src0_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t src1_offset = ggml_backend_webgpu_tensor_offset(src1);
+    size_t src1_misalignment = src1_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    src1_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+    size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
+    size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
+    dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
+
     std::vector<uint32_t> params = {
+        (uint32_t) (src0_misalignment / ggml_type_size(src0->type)),
+        (uint32_t) (src1_misalignment / ggml_type_size(src1->type)),
+        (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
         (uint32_t) dst->ne[1],  // number of rows in result (M)
         (uint32_t) dst->ne[0],  // number of columns in result (N)
         (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
@@ -584,20 +613,20 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
         { .binding = 0,
          .buffer = ggml_backend_webgpu_tensor_buf(src0),
          .offset = ggml_backend_webgpu_tensor_offset(src0),
-          .size = ggml_nbytes(src0) },
+          .size = ggml_webgpu_binding_size(src0, src0_misalignment) },
         { .binding = 1,
          .buffer = ggml_backend_webgpu_tensor_buf(src1),
          .offset = ggml_backend_webgpu_tensor_offset(src1),
-          .size = ggml_nbytes(src1) },
+          .size = ggml_webgpu_binding_size(src1, src1_misalignment) },
         { .binding = 2,
          .buffer = ggml_backend_webgpu_tensor_buf(dst),
          .offset = ggml_backend_webgpu_tensor_offset(dst),
-          .size = ggml_nbytes(dst) }
+          .size = ggml_webgpu_binding_size(dst, dst_misalignment) }
     };
 
     uint32_t wg_x =
         (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
 }
 
 // Returns true if node has enqueued work into the queue, false otherwise
@@ -907,7 +936,31 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
 }
 
 static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_mat_pipeline, wgsl_mul_mat, "mul_mat");
+    webgpu_pipeline_info pipeline_infos[4] = {
+        { .name = "mul_mat_f32_f32",
+          .shader_code = wgsl_mul_mat_f32_f32,
+          .src0_type = GGML_TYPE_F32,
+          .src1_type = GGML_TYPE_F32 },
+        { .name = "mul_mat_f16_f16",
+          .shader_code = wgsl_mul_mat_f16_f16,
+          .src0_type = GGML_TYPE_F16,
+          .src1_type = GGML_TYPE_F16 },
+        { .name = "mul_mat_f32_f16",
+          .shader_code = wgsl_mul_mat_f32_f16,
+          .src0_type = GGML_TYPE_F32,
+          .src1_type = GGML_TYPE_F16 },
+        { .name = "mul_mat_f16_f32",
+          .shader_code = wgsl_mul_mat_f16_f32,
+          .src0_type = GGML_TYPE_F16,
+          .src1_type = GGML_TYPE_F32 }
+    };
+
+    for (auto & pipeline_info : pipeline_infos) {
+        ggml_webgpu_create_pipeline(webgpu_ctx->device,
+                                    webgpu_ctx->mul_mat_pipeline[pipeline_info.src0_type][pipeline_info.src1_type],
+                                    pipeline_info.shader_code,
+                                    pipeline_info.name.data());
+    }
 }
 
 static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
@@ -1056,7 +1109,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_CPY | GGML_OP_SET_ROWS:
             return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_MUL_MAT:
-            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16);
         default:
            return false;
     }
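
The mul_mat path above now splits each tensor's byte offset into an offset rounded down to minStorageBufferOffsetAlignment plus a misalignment that the shader receives as an element count, and ggml_webgpu_binding_size rounds the binding size up to WEBGPU_STORAGE_BUF_BINDING_MULT so the binding still covers the data. A minimal Python sketch of that arithmetic, assuming illustrative constants (256 and 4) and an f32 tensor rather than the backend's actual limits:

# Sketch of the offset/size rounding used above; the constants are assumed example values.
MIN_STORAGE_OFFSET_ALIGN = 256  # assumed ctx->limits.minStorageBufferOffsetAlignment
BINDING_SIZE_MULT = 4           # assumed WEBGPU_STORAGE_BUF_BINDING_MULT

def split_offset(byte_offset):
    # Round the binding offset down to the required alignment and keep the remainder.
    misalignment = byte_offset & (MIN_STORAGE_OFFSET_ALIGN - 1)
    aligned_offset = byte_offset & ~(MIN_STORAGE_OFFSET_ALIGN - 1)
    return aligned_offset, misalignment

def binding_size(nbytes, misalignment):
    # Mirrors ggml_webgpu_binding_size: cover nbytes plus the slack in front of the
    # data, rounded up to the binding size multiple.
    return (nbytes + misalignment + BINDING_SIZE_MULT - 1) & ~(BINDING_SIZE_MULT - 1)

offset, mis = split_offset(260)   # tensor that starts 260 bytes into its buffer
print(offset, mis)                # 256 4
print(binding_size(1000, mis))    # 1004
print(mis // 4)                   # 1 -> the offset the shader receives, in f32 elements
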
ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 47 additions & 22 deletions
@@ -1,35 +1,60 @@
 import os
+import re
+import ast
 import argparse
 
-
-def escape_triple_quotes(wgsl):
-    # Simple defense in case of embedded """
-    return wgsl.replace('"""', '\\"""')
-
-
-def to_cpp_string_literal(varname, content):
-    return f'const char* wgsl_{varname} = R"({content})";\n'
-
+variants_regex = re.compile(r'//\s*Variants:\s*\n(\[.*?\])', re.DOTALL)
+
+def remove_variants_block(template_text):
+    return re.sub(variants_regex, '', template_text)
+
+def extract_variants(template_text):
+    match = re.search(variants_regex, template_text)
+    if not match:
+        return None
+    return ast.literal_eval(match.group(1))
+
+def write_shader(shader_name, shader_code, output_dir, outfile):
+    if output_dir:
+        wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
+        with open(wgsl_filename, 'w', encoding='utf-8') as f_out:
+            f_out.write(shader_code)
+    outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n')
+    outfile.write('\n')
+
+def generate_variants(shader_path, output_dir, outfile):
+    shader_base_name = shader_path.split("/")[-1].split(".")[0]
+    with open(shader_path, 'r', encoding='utf-8') as f:
+        shader_code = f.read()
+    variants = extract_variants(shader_code)
+    shader_code = remove_variants_block(shader_code)
+    if not variants:
+        write_shader(shader_base_name, shader_code, output_dir, outfile)
+    else:
+        for variant in variants:
+            shader_variant = shader_code
+            parts = []
+            for key, val in variant.items():
+                parts.append(val)
+                shader_variant = shader_variant.replace(key, val)
+            output_name = f"{shader_base_name}_" + "_".join(parts)
+            write_shader(output_name, shader_variant, output_dir, outfile)
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--input', required=True)
-    parser.add_argument('--output', required=True)
+    parser.add_argument('--input_dir', required=True)
+    parser.add_argument('--output_file', required=True)
+    parser.add_argument('--output_dir')
     args = parser.parse_args()
-
-    with open(args.output, 'w', encoding='utf-8') as out:
+    if args.output_dir:
+        os.makedirs(args.output_dir, exist_ok=True)
+    with open(args.output_file, 'w', encoding='utf-8') as out:
         out.write("// Auto-generated shader embedding \n\n")
-        for fname in sorted(os.listdir(args.input)):
+        for fname in sorted(os.listdir(args.input_dir)):
             if not fname.endswith('.wgsl'):
                 continue
-            shader_path = os.path.join(args.input, fname)
-            varname = os.path.splitext(fname)[0]
-            with open(shader_path, 'r', encoding='utf-8') as f:
-                content = f.read()
-            content = escape_triple_quotes(content)
-            out.write(to_cpp_string_literal(varname, content))
-            out.write('\n')
-
+            shader_path = os.path.join(args.input_dir, fname)
+            generate_variants(shader_path, args.output_dir, out)
 
 if __name__ == '__main__':
     main()
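
A minimal sketch of the variant expansion applied to a toy template; the template string here is invented for illustration, but the regex, the ast.literal_eval parse of the "// Variants:" block, and the key/value substitution mirror the script above.

# Toy demonstration of the variant expansion; only the mechanism matches embed_wgsl.py.
import re
import ast

variants_regex = re.compile(r'//\s*Variants:\s*\n(\[.*?\])', re.DOTALL)

template = """// Variants:
[
    { "SRC0_TYPE": "f32", "SRC1_TYPE": "f16" }
]

// Shader Template:
var<storage, read_write> src0: array<SRC0_TYPE>;
var<storage, read_write> src1: array<SRC1_TYPE>;
"""

variants = ast.literal_eval(variants_regex.search(template).group(1))
body = re.sub(variants_regex, '', template)

for variant in variants:
    parts = []
    shader = body
    for key, val in variant.items():
        parts.append(val)
        shader = shader.replace(key, val)
    name = "mul_mat_" + "_".join(parts)   # e.g. mul_mat_f32_f16
    print(f'const char* wgsl_{name} = R"({shader})";')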

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl renamed to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 31 additions & 4 deletions
@@ -1,4 +1,30 @@
+// Variants:
+[
+  {
+    "SRC0_TYPE" : "f32",
+    "SRC1_TYPE" : "f32"
+  },
+  {
+    "SRC0_TYPE" : "f16",
+    "SRC1_TYPE" : "f16"
+  },
+  {
+    "SRC0_TYPE" : "f16",
+    "SRC1_TYPE" : "f32"
+  },
+  {
+    "SRC0_TYPE" : "f32",
+    "SRC1_TYPE" : "f16"
+  }
+]
+
+// Shader Template:
+enable f16;
+
 struct MulMatParams {
+    offset_src0: u32, // in elements
+    offset_src1: u32, // in elements
+    offset_dst: u32, // in elements
     m: u32,
     n: u32,
     k: u32,
@@ -16,8 +42,8 @@ struct MulMatParams {
     broadcast3: u32
 };
 
-@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // N rows, K columns
-@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // M rows, K columns (transposed)
+@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // N rows, K columns
+@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // M rows, K columns (transposed)
 @group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
 
 @group(0) @binding(3) var<uniform> params: MulMatParams;
@@ -50,7 +76,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
     for (var i: u32 = 0u; i < params.k; i = i + 1u) {
         let src0_idx = src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01 + i;
         let src1_idx = src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11 + i;
-        sum = sum + src0[src0_idx] * src1[src1_idx];
+        sum = sum + f32(src0[params.offset_src0 + src0_idx]) * f32(src1[params.offset_src1 + src1_idx]);
     }
-    dst[dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
+    dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;
+
 }
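
To make one generated variant concrete, the same string substitution turns the templated bindings above into typed declarations; a small illustrative sketch (the expected output, shown in comments, follows from the binding lines in this diff):

# Substitution for the f16 x f32 variant of the binding declarations above.
bindings = (
    '@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>;\n'
    '@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>;\n'
)
for key, val in {"SRC0_TYPE": "f16", "SRC1_TYPE": "f32"}.items():
    bindings = bindings.replace(key, val)
print(bindings)
# @group(0) @binding(0) var<storage, read_write> src0: array<f16>;
# @group(0) @binding(1) var<storage, read_write> src1: array<f32>;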
