Skip to content

Commit 51252f0

Browse files
committed
Cleanup code
1 parent 7a2ae48 commit 51252f0

File tree

4 files changed

+103
-133
lines changed

4 files changed

+103
-133
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 94 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,6 @@ static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
5050

5151
/* Struct definitions */
5252

53-
struct webgpu_pipeline_info {
54-
std::string name;
55-
const char * shader_code;
56-
ggml_type src0_type;
57-
ggml_type src1_type;
58-
};
59-
6053
// Forward reference
6154
static void ggml_webgpu_create_buffer(wgpu::Device & device,
6255
wgpu::Buffer & buffer,
@@ -571,12 +564,12 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
571564
(uint32_t) dst->ne[1], // number of rows in result (M)
572565
(uint32_t) dst->ne[0], // number of columns in result (N)
573566
(uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
574-
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 1
575-
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 1
576-
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 2
577-
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 2
578-
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements) of src0 in dimension 3
579-
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements) of src1 in dimension 3
567+
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
568+
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
569+
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
570+
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
571+
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
572+
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
580573
(uint32_t) src0->ne[2], // batch size in dimension 2
581574
(uint32_t) src0->ne[3], // batch size in dimension 3
582575
(uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
@@ -596,16 +589,11 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
596589
.buffer = ggml_webgpu_tensor_buf(dst),
597590
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
598591
.size = ggml_webgpu_tensor_binding_size(ctx, dst) },
599-
// { .binding = 3,
600-
// .buffer = ctx->debug_dev_buf,
601-
// .offset = 0,
602-
// .size = ctx->debug_dev_buf.GetSize() }
603592
};
604593

605594
uint32_t wg_x =
606595
(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
607596
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
608-
//ggml_backend_webgpu_debug(ctx);
609597
}
610598

611599
// Returns true if node has enqueued work into the queue, false otherwise
@@ -915,103 +903,94 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
915903
}
916904

917905
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
918-
webgpu_pipeline_info pipeline_infos[22] = {
919-
{ .name = "mul_mat_f32_f32",
920-
.shader_code = wgsl_mul_mat_f32_f32,
921-
.src0_type = GGML_TYPE_F32,
922-
.src1_type = GGML_TYPE_F32 },
923-
{ .name = "mul_mat_f16_f16",
924-
.shader_code = wgsl_mul_mat_f16_f16,
925-
.src0_type = GGML_TYPE_F16,
926-
.src1_type = GGML_TYPE_F16 },
927-
{ .name = "mul_mat_f16_f32",
928-
.shader_code = wgsl_mul_mat_f16_f32,
929-
.src0_type = GGML_TYPE_F16,
930-
.src1_type = GGML_TYPE_F32 },
931-
{ .name = "mul_mat_q4_0_f32",
932-
.shader_code = wgsl_mul_mat_q4_0_f32,
933-
.src0_type = GGML_TYPE_Q4_0,
934-
.src1_type = GGML_TYPE_F32 },
935-
{ .name = "mul_mat_q4_1_f32",
936-
.shader_code = wgsl_mul_mat_q4_1_f32,
937-
.src0_type = GGML_TYPE_Q4_1,
938-
.src1_type = GGML_TYPE_F32 },
939-
{ .name = "mul_mat_q5_0_f32",
940-
.shader_code = wgsl_mul_mat_q5_0_f32,
941-
.src0_type = GGML_TYPE_Q5_0,
942-
.src1_type = GGML_TYPE_F32 },
943-
{ .name = "mul_mat_q5_1_f32",
944-
.shader_code = wgsl_mul_mat_q5_1_f32,
945-
.src0_type = GGML_TYPE_Q5_1,
946-
.src1_type = GGML_TYPE_F32 },
947-
{ .name = "mul_mat_q8_0_f32",
948-
.shader_code = wgsl_mul_mat_q8_0_f32,
949-
.src0_type = GGML_TYPE_Q8_0,
950-
.src1_type = GGML_TYPE_F32 },
951-
{ .name = "mul_mat_q2_k_f32",
952-
.shader_code = wgsl_mul_mat_q2_k_f32,
953-
.src0_type = GGML_TYPE_Q2_K,
954-
.src1_type = GGML_TYPE_F32 },
955-
{ .name = "mul_mat_q3_k_f32",
956-
.shader_code = wgsl_mul_mat_q3_k_f32,
957-
.src0_type = GGML_TYPE_Q3_K,
958-
.src1_type = GGML_TYPE_F32 },
959-
{ .name = "mul_mat_q4_k_f32",
960-
.shader_code = wgsl_mul_mat_q4_k_f32,
961-
.src0_type = GGML_TYPE_Q4_K,
962-
.src1_type = GGML_TYPE_F32 },
963-
{ .name = "mul_mat_q5_k_f32",
964-
.shader_code = wgsl_mul_mat_q5_k_f32,
965-
.src0_type = GGML_TYPE_Q5_K,
966-
.src1_type = GGML_TYPE_F32 },
967-
{ .name = "mul_mat_q6_k_f32",
968-
.shader_code = wgsl_mul_mat_q6_k_f32,
969-
.src0_type = GGML_TYPE_Q6_K,
970-
.src1_type = GGML_TYPE_F32 },
971-
{ .name = "mul_mat_iq2_xxs_f32",
972-
.shader_code = wgsl_mul_mat_iq2_xxs_f32,
973-
.src0_type = GGML_TYPE_IQ2_XXS,
974-
.src1_type = GGML_TYPE_F32 },
975-
{ .name = "mul_mat_iq2_xs_f32",
976-
.shader_code = wgsl_mul_mat_iq2_xs_f32,
977-
.src0_type = GGML_TYPE_IQ2_XS,
978-
.src1_type = GGML_TYPE_F32 },
979-
{ .name = "mul_mat_iq2_s_f32",
980-
.shader_code = wgsl_mul_mat_iq2_s_f32,
981-
.src0_type = GGML_TYPE_IQ2_S,
982-
.src1_type = GGML_TYPE_F32 },
983-
{ .name = "mul_mat_iq3_xxs_f32",
984-
.shader_code = wgsl_mul_mat_iq3_xxs_f32,
985-
.src0_type = GGML_TYPE_IQ3_XXS,
986-
.src1_type = GGML_TYPE_F32 },
987-
{ .name = "mul_mat_iq3_s_f32",
988-
.shader_code = wgsl_mul_mat_iq3_s_f32,
989-
.src0_type = GGML_TYPE_IQ3_S,
990-
.src1_type = GGML_TYPE_F32 },
991-
{ .name = "mul_mat_iq1_s_f32",
992-
.shader_code = wgsl_mul_mat_iq1_s_f32,
993-
.src0_type = GGML_TYPE_IQ1_S,
994-
.src1_type = GGML_TYPE_F32 },
995-
{ .name = "mul_mat_iq1_m_f32",
996-
.shader_code = wgsl_mul_mat_iq1_m_f32,
997-
.src0_type = GGML_TYPE_IQ1_M,
998-
.src1_type = GGML_TYPE_F32 },
999-
{ .name = "mul_mat_iq4_nl_f32",
1000-
.shader_code = wgsl_mul_mat_iq4_nl_f32,
1001-
.src0_type = GGML_TYPE_IQ4_NL,
1002-
.src1_type = GGML_TYPE_F32 },
1003-
{ .name = "mul_mat_iq4_xs_f32",
1004-
.shader_code = wgsl_mul_mat_iq4_xs_f32,
1005-
.src0_type = GGML_TYPE_IQ4_XS,
1006-
.src1_type = GGML_TYPE_F32 }
1007-
};
1008-
1009-
for (auto & pipeline_info : pipeline_infos) {
1010-
ggml_webgpu_create_pipeline(webgpu_ctx->device,
1011-
webgpu_ctx->mul_mat_pipeline[pipeline_info.src0_type][pipeline_info.src1_type],
1012-
pipeline_info.shader_code,
1013-
pipeline_info.name.data());
1014-
}
906+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
907+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F32][GGML_TYPE_F32],
908+
wgsl_mul_mat_f32_f32,
909+
"mul_mat_f32_f32");
910+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
911+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F16],
912+
wgsl_mul_mat_f16_f16,
913+
"mul_mat_f16_f16");
914+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
915+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_F16][GGML_TYPE_F32],
916+
wgsl_mul_mat_f16_f32,
917+
"mul_mat_f16_f32");
918+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
919+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_0][GGML_TYPE_F32],
920+
wgsl_mul_mat_q4_0_f32,
921+
"mul_mat_q4_0_f32");
922+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
923+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_1][GGML_TYPE_F32],
924+
wgsl_mul_mat_q4_1_f32,
925+
"mul_mat_q4_1_f32");
926+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
927+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_0][GGML_TYPE_F32],
928+
wgsl_mul_mat_q5_0_f32,
929+
"mul_mat_q5_0_f32");
930+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
931+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_1][GGML_TYPE_F32],
932+
wgsl_mul_mat_q5_1_f32,
933+
"mul_mat_q5_1_f32");
934+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
935+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q8_0][GGML_TYPE_F32],
936+
wgsl_mul_mat_q8_0_f32,
937+
"mul_mat_q8_0_f32");
938+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
939+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q2_K][GGML_TYPE_F32],
940+
wgsl_mul_mat_q2_k_f32,
941+
"mul_mat_q2_k_f32");
942+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
943+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q3_K][GGML_TYPE_F32],
944+
wgsl_mul_mat_q3_k_f32,
945+
"mul_mat_q3_k_f32");
946+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
947+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q4_K][GGML_TYPE_F32],
948+
wgsl_mul_mat_q4_k_f32,
949+
"mul_mat_q4_k_f32");
950+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
951+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q5_K][GGML_TYPE_F32],
952+
wgsl_mul_mat_q5_k_f32,
953+
"mul_mat_q5_k_f32");
954+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
955+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_Q6_K][GGML_TYPE_F32],
956+
wgsl_mul_mat_q6_k_f32,
957+
"mul_mat_q6_k_f32");
958+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
959+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32],
960+
wgsl_mul_mat_iq2_xxs_f32,
961+
"mul_mat_iq2_xxs_f32");
962+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
963+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_XS][GGML_TYPE_F32],
964+
wgsl_mul_mat_iq2_xs_f32,
965+
"mul_mat_iq2_xs_f32");
966+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
967+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ2_S][GGML_TYPE_F32],
968+
wgsl_mul_mat_iq2_s_f32,
969+
"mul_mat_iq2_s_f32");
970+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
971+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32],
972+
wgsl_mul_mat_iq3_xxs_f32,
973+
"mul_mat_iq3_xxs_f32");
974+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
975+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ3_S][GGML_TYPE_F32],
976+
wgsl_mul_mat_iq3_s_f32,
977+
"mul_mat_iq3_s_f32");
978+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
979+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_S][GGML_TYPE_F32],
980+
wgsl_mul_mat_iq1_s_f32,
981+
"mul_mat_iq1_s_f32");
982+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
983+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ1_M][GGML_TYPE_F32],
984+
wgsl_mul_mat_iq1_m_f32,
985+
"mul_mat_iq1_m_f32");
986+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
987+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_NL][GGML_TYPE_F32],
988+
wgsl_mul_mat_iq4_nl_f32,
989+
"mul_mat_iq4_nl_f32");
990+
ggml_webgpu_create_pipeline(webgpu_ctx->device,
991+
webgpu_ctx->mul_mat_pipeline[GGML_TYPE_IQ4_XS][GGML_TYPE_F32],
992+
wgsl_mul_mat_iq4_xs_f32,
993+
"mul_mat_iq4_xs_f32");
1015994
}
1016995

1017996
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ def parse_decls(decls_text):
1818

1919
def replace_placeholders(shader_text, replacements):
2020
for key, val in replacements.items():
21-
pattern = rf'\b{re.escape(key)}\b'
21+
# Match {{KEY}} literally, where KEY is escaped
22+
pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
2223
shader_text = re.sub(pattern, str(val), shader_text)
2324
return shader_text
2425

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,13 +1731,13 @@ enable f16;
17311731
DECLS
17321732

17331733
struct MulMatParams {
1734-
offset_src0: u32, // in elements
1735-
offset_src1: u32, // in elements
1736-
offset_dst: u32, // in elements
1734+
offset_src0: u32, // in elements/blocks
1735+
offset_src1: u32, // in elements/blocks
1736+
offset_dst: u32, // in elements/blocks
17371737
m: u32,
17381738
n: u32,
17391739
k: u32,
1740-
// all strides are in elements
1740+
// all strides are in elements/blocks
17411741
stride_01: u32,
17421742
stride_11: u32,
17431743
stride_02: u32,
@@ -1751,10 +1751,9 @@ struct MulMatParams {
17511751
broadcast3: u32
17521752
};
17531753

1754-
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // N rows, K columns
1755-
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // M rows, K columns (transposed)
1754+
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // N rows, K columns
1755+
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // M rows, K columns (transposed)
17561756
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
1757-
//@group(0) @binding(3) var<storage, read_write> debug: array<f32>;
17581757

17591758
@group(0) @binding(3) var<uniform> params: MulMatParams;
17601759

@@ -1786,7 +1785,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
17861785
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
17871786

17881787
var sum = 0.0;
1789-
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
1788+
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
17901789
sum += multiply_add(src0_idx_base, src1_idx_base, i);
17911790
}
17921791
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.n + col] = sum;

tests/test-backend-ops.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -979,15 +979,6 @@ struct test_case {
979979
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
980980
init_tensor_uniform(t);
981981
}
982-
// print first 32 elements of each tensor
983-
// for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
984-
// if (strcmp(ggml_get_name(t), "a") == 0) {
985-
// std::vector<float> values = tensor_to_float(t);
986-
// for (int i = 0; i < 32; i++) {
987-
// printf("%s[%d] = %f\n", ggml_get_name(t), i, values[i]);
988-
// }
989-
// }
990-
// }
991982
}
992983

993984
virtual size_t op_size(ggml_tensor * t) {

0 commit comments

Comments (0)