Skip to content

Commit c3611f9

Browse files
committed
Add matmul support for basic quantization types
1 parent 1aa40f1 commit c3611f9

File tree

2 files changed

+281
-8
lines changed

2 files changed

+281
-8
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct webgpu_context_struct {
129129
webgpu_buf_pool set_rows_error_buf_pool;
130130

131131
wgpu::ComputePipeline memset_pipeline;
132-
wgpu::ComputePipeline mul_mat_pipeline[3][2];
132+
wgpu::ComputePipeline mul_mat_pipeline[10][2];
133133
wgpu::ComputePipeline set_rows_pipeline;
134134
wgpu::ComputePipeline cpy_pipeline;
135135

@@ -910,7 +910,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
910910
}
911911

912912
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
913-
webgpu_pipeline_info pipeline_infos[6] = {
913+
webgpu_pipeline_info pipeline_infos[16] = {
914914
{ .name = "mul_mat_f32_f32",
915915
.shader_code = wgsl_mul_mat_f32_f32,
916916
.src0_type = GGML_TYPE_F32,
@@ -934,6 +934,46 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
934934
{ .name = "mul_mat_q4_0_f16",
935935
.shader_code = wgsl_mul_mat_q4_0_f16,
936936
.src0_type = GGML_TYPE_Q4_0,
937+
.src1_type = GGML_TYPE_F16 },
938+
{ .name = "mul_mat_q4_1_f32",
939+
.shader_code = wgsl_mul_mat_q4_1_f32,
940+
.src0_type = GGML_TYPE_Q4_1,
941+
.src1_type = GGML_TYPE_F32 },
942+
{ .name = "mul_mat_q4_1_f16",
943+
.shader_code = wgsl_mul_mat_q4_1_f16,
944+
.src0_type = GGML_TYPE_Q4_1,
945+
.src1_type = GGML_TYPE_F16 },
946+
{ .name = "mul_mat_q5_0_f32",
947+
.shader_code = wgsl_mul_mat_q5_0_f32,
948+
.src0_type = GGML_TYPE_Q5_0,
949+
.src1_type = GGML_TYPE_F32 },
950+
{ .name = "mul_mat_q5_0_f16",
951+
.shader_code = wgsl_mul_mat_q5_0_f16,
952+
.src0_type = GGML_TYPE_Q5_0,
953+
.src1_type = GGML_TYPE_F16 },
954+
{ .name = "mul_mat_q5_1_f32",
955+
.shader_code = wgsl_mul_mat_q5_1_f32,
956+
.src0_type = GGML_TYPE_Q5_1,
957+
.src1_type = GGML_TYPE_F32 },
958+
{ .name = "mul_mat_q5_1_f16",
959+
.shader_code = wgsl_mul_mat_q5_1_f16,
960+
.src0_type = GGML_TYPE_Q5_1,
961+
.src1_type = GGML_TYPE_F16 },
962+
{ .name = "mul_mat_q8_0_f32",
963+
.shader_code = wgsl_mul_mat_q8_0_f32,
964+
.src0_type = GGML_TYPE_Q8_0,
965+
.src1_type = GGML_TYPE_F32 },
966+
{ .name = "mul_mat_q8_0_f16",
967+
.shader_code = wgsl_mul_mat_q8_0_f16,
968+
.src0_type = GGML_TYPE_Q8_0,
969+
.src1_type = GGML_TYPE_F16 },
970+
{ .name = "mul_mat_q8_1_f32",
971+
.shader_code = wgsl_mul_mat_q8_1_f32,
972+
.src0_type = GGML_TYPE_Q8_1,
973+
.src1_type = GGML_TYPE_F32 },
974+
{ .name = "mul_mat_q8_1_f16",
975+
.shader_code = wgsl_mul_mat_q8_1_f16,
976+
.src0_type = GGML_TYPE_Q8_1,
937977
.src1_type = GGML_TYPE_F16 }
938978
};
939979

@@ -1015,12 +1055,31 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
10151055
case GGML_OP_VIEW:
10161056
case GGML_OP_PERMUTE:
10171057
return true;
1018-
case GGML_OP_CPY | GGML_OP_SET_ROWS:
1058+
case GGML_OP_CPY:
1059+
case GGML_OP_SET_ROWS:
10191060
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
10201061
case GGML_OP_MUL_MAT:
1021-
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
1022-
op->src[0]->type == GGML_TYPE_Q4_0) &&
1023-
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16);
1062+
switch(op->src[0]->type) {
1063+
case GGML_TYPE_F32:
1064+
case GGML_TYPE_F16:
1065+
case GGML_TYPE_Q4_0:
1066+
case GGML_TYPE_Q4_1:
1067+
case GGML_TYPE_Q5_0:
1068+
case GGML_TYPE_Q5_1:
1069+
case GGML_TYPE_Q8_0:
1070+
case GGML_TYPE_Q8_1:
1071+
break;
1072+
default:
1073+
return false;
1074+
}
1075+
switch(op->src[1]->type) {
1076+
case GGML_TYPE_F32:
1077+
case GGML_TYPE_F16:
1078+
break;
1079+
default:
1080+
return false;
1081+
}
1082+
return true;
10241083
default:
10251084
return false;
10261085
}

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 216 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,87 @@
4848
"BLOCK_SIZE": 32
4949
},
5050
"DECLS": "Q4_0"
51+
},
52+
{
53+
"REPLS": {
54+
"SRC0_TYPE": "q4_1",
55+
"SRC1_TYPE": "f32",
56+
"BLOCK_SIZE": 32
57+
},
58+
"DECLS": "Q4_1"
59+
},
60+
{
61+
"REPLS": {
62+
"SRC0_TYPE": "q4_1",
63+
"SRC1_TYPE": "f16",
64+
"BLOCK_SIZE": 32
65+
},
66+
"DECLS": "Q4_1"
67+
},
68+
{
69+
"REPLS": {
70+
"SRC0_TYPE": "q5_0",
71+
"SRC1_TYPE": "f32",
72+
"BLOCK_SIZE": 32
73+
},
74+
"DECLS": "Q5_0"
75+
},
76+
{
77+
"REPLS": {
78+
"SRC0_TYPE": "q5_0",
79+
"SRC1_TYPE": "f16",
80+
"BLOCK_SIZE": 32
81+
},
82+
"DECLS": "Q5_0"
83+
},
84+
{
85+
"REPLS": {
86+
"SRC0_TYPE": "q5_1",
87+
"SRC1_TYPE": "f32",
88+
"BLOCK_SIZE": 32
89+
},
90+
"DECLS": "Q5_1"
91+
},
92+
{
93+
"REPLS": {
94+
"SRC0_TYPE": "q5_1",
95+
"SRC1_TYPE": "f16",
96+
"BLOCK_SIZE": 32
97+
},
98+
"DECLS": "Q5_1"
99+
},
100+
{
101+
"REPLS": {
102+
"SRC0_TYPE": "q8_0",
103+
"SRC1_TYPE": "f32",
104+
"BLOCK_SIZE": 32
105+
},
106+
"DECLS": "Q8_0"
107+
},
108+
{
109+
"REPLS": {
110+
"SRC0_TYPE": "q8_0",
111+
"SRC1_TYPE": "f16",
112+
"BLOCK_SIZE": 32
113+
},
114+
"DECLS": "Q8_0"
115+
},
116+
{
117+
"REPLS": {
118+
"SRC0_TYPE": "q8_1",
119+
"SRC1_TYPE": "f32",
120+
"BLOCK_SIZE": 32
121+
},
122+
"DECLS": "Q8_1"
123+
},
124+
{
125+
"REPLS": {
126+
"SRC0_TYPE": "q8_1",
127+
"SRC1_TYPE": "f16",
128+
"BLOCK_SIZE": 32
129+
},
130+
"DECLS": "Q8_1"
51131
}
52-
53132
]
54133

55134
#end(VARIANTS)
@@ -69,7 +148,7 @@ struct q4_0 {
69148
};
70149

71150
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
72-
let block_q4_0: q4_0 = src0[src0_idx_base + offset];
151+
let block_q4_0 = src0[src0_idx_base + offset];
73152
let d = f32(block_q4_0.d);
74153
var sum: f32 = 0.0;
75154
for (var j: u32 = 0; j < 4; j++) {
@@ -86,6 +165,141 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
86165
return sum;
87166
}
88167
#enddecl(Q4_0)
168+
169+
#decl(Q4_1)
170+
struct q4_1 {
171+
d: f16,
172+
m: f16,
173+
qs: array<u32, 4>
174+
};
175+
176+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
177+
let block_q4_1 = src0[src0_idx_base + offset];
178+
let d = f32(block_q4_1.d);
179+
let m = f32(block_q4_1.m);
180+
var sum: f32 = 0.0;
181+
for (var j: u32 = 0; j < 4; j++) {
182+
let q_packed = block_q4_1.qs[j];
183+
for (var k: u32 = 0; k < 4; k++) {
184+
let q_byte = (q_packed >> (k * 8)) & 0xFF;
185+
let q_hi = f32((q_byte >> 4) & 0xF) * d + m;
186+
let q_lo = f32(q_byte & 0xF) * d + m;
187+
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
188+
sum += q_lo * f32(src1[src1_offset]);
189+
sum += q_hi * f32(src1[src1_offset + 16]);
190+
}
191+
}
192+
return sum;
193+
}
194+
#enddecl(Q4_1)
195+
196+
#decl(Q5_0)
197+
struct q5_0 {
198+
d: f16,
199+
qh: array<f16, 2>,
200+
qs: array<f16, 8>
201+
};
202+
203+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
204+
let block_q5_0 = src0[src0_idx_base + offset];
205+
let d = f32(block_q5_0.d);
206+
var sum: f32 = 0.0;
207+
let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1]));
208+
for (var j: u32 = 0; j < 4; j++) {
209+
let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1]));
210+
for (var k: u32 = 0; k < 4; k++) {
211+
let q_byte = (q_packed >> (k * 8)) & 0xFF;
212+
let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10;
213+
let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d;
214+
let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10;
215+
let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d;
216+
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
217+
sum += q_lo * f32(src1[src1_offset]);
218+
sum += q_hi * f32(src1[src1_offset + 16]);
219+
}
220+
}
221+
return sum;
222+
}
223+
#enddecl(Q5_0)
224+
225+
#decl(Q5_1)
226+
struct q5_1 {
227+
d: f16,
228+
m: f16,
229+
qh: u32,
230+
qs: array<u32, 4>
231+
};
232+
233+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
234+
let block_q5_1 = src0[src0_idx_base + offset];
235+
let d = f32(block_q5_1.d);
236+
let m = f32(block_q5_1.m);
237+
var sum: f32 = 0.0;
238+
for (var j: u32 = 0; j < 4; j++) {
239+
let q_packed = block_q5_1.qs[j];
240+
for (var k: u32 = 0; k < 4; k++) {
241+
let q_byte = (q_packed >> (k * 8)) & 0xFF;
242+
let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10;
243+
let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m;
244+
let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10;
245+
let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m;
246+
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
247+
sum += q_lo * f32(src1[src1_offset]);
248+
sum += q_hi * f32(src1[src1_offset + 16]);
249+
}
250+
}
251+
return sum;
252+
}
253+
#enddecl(Q5_1)
254+
255+
#decl(Q8_0)
256+
struct q8_0 {
257+
d: f16,
258+
qs: array<f16, 16>
259+
};
260+
261+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
262+
let block_q8_0 = src0[src0_idx_base + offset];
263+
let d = f32(block_q8_0.d);
264+
var sum: f32 = 0.0;
265+
for (var j: u32 = 0; j < 8; j++) {
266+
let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1]));
267+
for (var k: u32 = 0; k < 4; k++) {
268+
let q_byte = bitcast<i32>((((q_packed >> (k * 8)) & 0xFF) << 24)) >> 24; // sign-extend
269+
let q_val = f32(q_byte) * d;
270+
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
271+
sum += q_val * f32(src1[src1_offset]);
272+
}
273+
}
274+
return sum;
275+
}
276+
#enddecl(Q8_0)
277+
278+
#decl(Q8_1)
279+
struct q8_1 {
280+
d: f16,
281+
m: f16,
282+
qs: array<u32, 8>
283+
};
284+
285+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
286+
let block_q8_1 = src0[src0_idx_base + offset];
287+
let d = f32(block_q8_1.d);
288+
let m = f32(block_q8_1.m);
289+
var sum: f32 = 0.0;
290+
for (var j: u32 = 0; j < 8; j++) {
291+
let q_packed = block_q8_1.qs[j];
292+
for (var k: u32 = 0; k < 4; k++) {
293+
let q_byte = bitcast<i32>((((q_packed >> (k * 8)) & 0xFF) << 24)) >> 24; // sign-extend
294+
let q_val = f32(q_byte) * d + m;
295+
let src1_offset = src1_idx_base + offset * 32 + j * 4 + k;
296+
sum += q_val * f32(src1[src1_offset]);
297+
}
298+
}
299+
return sum;
300+
}
301+
#enddecl(Q8_1)
302+
89303
#end(DECLS)
90304

91305
#define(SHADER)

0 commit comments

Comments
 (0)