Skip to content

Commit de4da87

Browse files
committed
Add q2_k and q3_k quantization
1 parent c3611f9 commit de4da87

File tree

2 files changed

+144
-105
lines changed

2 files changed

+144
-105
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct webgpu_context_struct {
129129
webgpu_buf_pool set_rows_error_buf_pool;
130130

131131
wgpu::ComputePipeline memset_pipeline;
132-
wgpu::ComputePipeline mul_mat_pipeline[10][2];
132+
wgpu::ComputePipeline mul_mat_pipeline[12][2];
133133
wgpu::ComputePipeline set_rows_pipeline;
134134
wgpu::ComputePipeline cpy_pipeline;
135135

@@ -910,7 +910,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
910910
}
911911

912912
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
913-
webgpu_pipeline_info pipeline_infos[16] = {
913+
webgpu_pipeline_info pipeline_infos[10] = {
914914
{ .name = "mul_mat_f32_f32",
915915
.shader_code = wgsl_mul_mat_f32_f32,
916916
.src0_type = GGML_TYPE_F32,
@@ -919,10 +919,6 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
919919
.shader_code = wgsl_mul_mat_f16_f16,
920920
.src0_type = GGML_TYPE_F16,
921921
.src1_type = GGML_TYPE_F16 },
922-
{ .name = "mul_mat_f32_f16",
923-
.shader_code = wgsl_mul_mat_f32_f16,
924-
.src0_type = GGML_TYPE_F32,
925-
.src1_type = GGML_TYPE_F16 },
926922
{ .name = "mul_mat_f16_f32",
927923
.shader_code = wgsl_mul_mat_f16_f32,
928924
.src0_type = GGML_TYPE_F16,
@@ -931,50 +927,30 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
931927
.shader_code = wgsl_mul_mat_q4_0_f32,
932928
.src0_type = GGML_TYPE_Q4_0,
933929
.src1_type = GGML_TYPE_F32 },
934-
{ .name = "mul_mat_q4_0_f16",
935-
.shader_code = wgsl_mul_mat_q4_0_f16,
936-
.src0_type = GGML_TYPE_Q4_0,
937-
.src1_type = GGML_TYPE_F16 },
938930
{ .name = "mul_mat_q4_1_f32",
939931
.shader_code = wgsl_mul_mat_q4_1_f32,
940932
.src0_type = GGML_TYPE_Q4_1,
941933
.src1_type = GGML_TYPE_F32 },
942-
{ .name = "mul_mat_q4_1_f16",
943-
.shader_code = wgsl_mul_mat_q4_1_f16,
944-
.src0_type = GGML_TYPE_Q4_1,
945-
.src1_type = GGML_TYPE_F16 },
946934
{ .name = "mul_mat_q5_0_f32",
947935
.shader_code = wgsl_mul_mat_q5_0_f32,
948936
.src0_type = GGML_TYPE_Q5_0,
949937
.src1_type = GGML_TYPE_F32 },
950-
{ .name = "mul_mat_q5_0_f16",
951-
.shader_code = wgsl_mul_mat_q5_0_f16,
952-
.src0_type = GGML_TYPE_Q5_0,
953-
.src1_type = GGML_TYPE_F16 },
954938
{ .name = "mul_mat_q5_1_f32",
955939
.shader_code = wgsl_mul_mat_q5_1_f32,
956940
.src0_type = GGML_TYPE_Q5_1,
957941
.src1_type = GGML_TYPE_F32 },
958-
{ .name = "mul_mat_q5_1_f16",
959-
.shader_code = wgsl_mul_mat_q5_1_f16,
960-
.src0_type = GGML_TYPE_Q5_1,
961-
.src1_type = GGML_TYPE_F16 },
962942
{ .name = "mul_mat_q8_0_f32",
963943
.shader_code = wgsl_mul_mat_q8_0_f32,
964944
.src0_type = GGML_TYPE_Q8_0,
965945
.src1_type = GGML_TYPE_F32 },
966-
{ .name = "mul_mat_q8_0_f16",
967-
.shader_code = wgsl_mul_mat_q8_0_f16,
968-
.src0_type = GGML_TYPE_Q8_0,
969-
.src1_type = GGML_TYPE_F16 },
970-
{ .name = "mul_mat_q8_1_f32",
971-
.shader_code = wgsl_mul_mat_q8_1_f32,
972-
.src0_type = GGML_TYPE_Q8_1,
946+
{ .name = "mul_mat_q2_k_f32",
947+
.shader_code = wgsl_mul_mat_q2_k_f32,
948+
.src0_type = GGML_TYPE_Q2_K,
973949
.src1_type = GGML_TYPE_F32 },
974-
{ .name = "mul_mat_q8_1_f16",
975-
.shader_code = wgsl_mul_mat_q8_1_f16,
976-
.src0_type = GGML_TYPE_Q8_1,
977-
.src1_type = GGML_TYPE_F16 }
950+
{ .name = "mul_mat_q3_k_f32",
951+
.shader_code = wgsl_mul_mat_q3_k_f32,
952+
.src0_type = GGML_TYPE_Q3_K,
953+
.src1_type = GGML_TYPE_F32 }
978954
};
979955

980956
for (auto & pipeline_info : pipeline_infos) {
@@ -1058,28 +1034,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
10581034
case GGML_OP_CPY:
10591035
case GGML_OP_SET_ROWS:
10601036
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
1061-
case GGML_OP_MUL_MAT:
1062-
switch(op->src[0]->type) {
1063-
case GGML_TYPE_F32:
1064-
case GGML_TYPE_F16:
1065-
case GGML_TYPE_Q4_0:
1066-
case GGML_TYPE_Q4_1:
1067-
case GGML_TYPE_Q5_0:
1068-
case GGML_TYPE_Q5_1:
1069-
case GGML_TYPE_Q8_0:
1070-
case GGML_TYPE_Q8_1:
1071-
break;
1072-
default:
1073-
return false;
1074-
}
1037+
case GGML_OP_MUL_MAT: {
10751038
switch(op->src[1]->type) {
1076-
case GGML_TYPE_F32:
10771039
case GGML_TYPE_F16:
1078-
break;
1040+
return op->src[0]->type == GGML_TYPE_F16;
1041+
case GGML_TYPE_F32:
1042+
switch(op->src[0]->type) {
1043+
case GGML_TYPE_F32:
1044+
case GGML_TYPE_F16:
1045+
case GGML_TYPE_Q4_0:
1046+
case GGML_TYPE_Q4_1:
1047+
case GGML_TYPE_Q5_0:
1048+
case GGML_TYPE_Q5_1:
1049+
case GGML_TYPE_Q8_0:
1050+
case GGML_TYPE_Q2_K:
1051+
case GGML_TYPE_Q3_K:
1052+
return true;
1053+
default:
1054+
return false;
1055+
}
10791056
default:
10801057
return false;
10811058
}
1082-
return true;
1059+
}
10831060
default:
10841061
return false;
10851062
}

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 117 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@
2525
},
2626
"DECLS" : "FLOAT"
2727
},
28-
{
29-
"REPLS": {
30-
"SRC0_TYPE" : "f32",
31-
"SRC1_TYPE" : "f16",
32-
"BLOCK_SIZE" : 1
33-
},
34-
"DECLS" : "FLOAT"
35-
},
3628
{
3729
"REPLS": {
3830
"SRC0_TYPE": "q4_0",
@@ -41,14 +33,6 @@
4133
},
4234
"DECLS": "Q4_0"
4335
},
44-
{
45-
"REPLS": {
46-
"SRC0_TYPE": "q4_0",
47-
"SRC1_TYPE": "f16",
48-
"BLOCK_SIZE": 32
49-
},
50-
"DECLS": "Q4_0"
51-
},
5236
{
5337
"REPLS": {
5438
"SRC0_TYPE": "q4_1",
@@ -57,14 +41,6 @@
5741
},
5842
"DECLS": "Q4_1"
5943
},
60-
{
61-
"REPLS": {
62-
"SRC0_TYPE": "q4_1",
63-
"SRC1_TYPE": "f16",
64-
"BLOCK_SIZE": 32
65-
},
66-
"DECLS": "Q4_1"
67-
},
6844
{
6945
"REPLS": {
7046
"SRC0_TYPE": "q5_0",
@@ -73,14 +49,6 @@
7349
},
7450
"DECLS": "Q5_0"
7551
},
76-
{
77-
"REPLS": {
78-
"SRC0_TYPE": "q5_0",
79-
"SRC1_TYPE": "f16",
80-
"BLOCK_SIZE": 32
81-
},
82-
"DECLS": "Q5_0"
83-
},
8452
{
8553
"REPLS": {
8654
"SRC0_TYPE": "q5_1",
@@ -89,14 +57,6 @@
8957
},
9058
"DECLS": "Q5_1"
9159
},
92-
{
93-
"REPLS": {
94-
"SRC0_TYPE": "q5_1",
95-
"SRC1_TYPE": "f16",
96-
"BLOCK_SIZE": 32
97-
},
98-
"DECLS": "Q5_1"
99-
},
10060
{
10161
"REPLS": {
10262
"SRC0_TYPE": "q8_0",
@@ -107,27 +67,19 @@
10767
},
10868
{
10969
"REPLS": {
110-
"SRC0_TYPE": "q8_0",
111-
"SRC1_TYPE": "f16",
112-
"BLOCK_SIZE": 32
113-
},
114-
"DECLS": "Q8_0"
115-
},
116-
{
117-
"REPLS": {
118-
"SRC0_TYPE": "q8_1",
70+
"SRC0_TYPE": "q2_k",
11971
"SRC1_TYPE": "f32",
120-
"BLOCK_SIZE": 32
72+
"BLOCK_SIZE": 256
12173
},
122-
"DECLS": "Q8_1"
74+
"DECLS": "Q2_K"
12375
},
12476
{
12577
"REPLS": {
126-
"SRC0_TYPE": "q8_1",
127-
"SRC1_TYPE": "f16",
128-
"BLOCK_SIZE": 32
78+
"SRC0_TYPE": "q3_k",
79+
"SRC1_TYPE": "f32",
80+
"BLOCK_SIZE": 256
12981
},
130-
"DECLS": "Q8_1"
82+
"DECLS": "Q3_K"
13183
}
13284
]
13385

@@ -300,6 +252,116 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
300252
}
301253
#enddecl(Q8_1)
302254

255+
#decl(Q2_K)
256+
// 16 blocks of 16 elements each
257+
struct q2_k {
258+
scales: array<u32, 4>,
259+
qs: array<u32, 16>,
260+
d: f16,
261+
dmin: f16
262+
};
263+
264+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
265+
let block = src0[src0_idx_base + offset];
266+
let d = f32(block.d);
267+
let m = f32(block.dmin);
268+
var sum = 0.0;
269+
var src1_i = src1_idx_base + offset * 256;
270+
var is: u32 = 0;
271+
// 2 halves of the block (128 elements each)
272+
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
273+
// 4 groups (each group has 2 blocks of 16 elements)
274+
for (var shift: u32 = 0; shift < 8; shift += 2) {
275+
// 2 blocks
276+
for (var k: u32 = 0; k < 32; k += 16) {
277+
let sc = (block.scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
278+
is++;
279+
let dl = d * f32(sc & 0xF);
280+
let ml = m * f32(sc >> 4);
281+
for (var l: u32 = 0u; l < 16; l++) {
282+
let q_idx = q_b_idx + k + l;
283+
let q_byte = (block.qs[q_idx / 4] >> ((q_idx % 4) * 8)) & 0xFF;
284+
let qs_val = (q_byte >> shift) & 3;
285+
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
286+
src1_i++;
287+
}
288+
}
289+
}
290+
}
291+
return sum;
292+
}
293+
294+
#enddecl(Q2_K)
295+
296+
#decl(Q3_K)
297+
// 16 blocks of 16 elements each
298+
struct q3_k {
299+
hmask: array<f16, 16>,
300+
qs: array<f16, 32>,
301+
scales: array<f16, 6>, // 6-bit quantized values
302+
d: f16
303+
};
304+
305+
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
306+
let block = src0[src0_idx_base + offset];
307+
let d = f32(block.d);
308+
309+
// extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale,
310+
// and 2-bits from the last 4 bytes
311+
let kmask1: u32 = 0x03030303;
312+
let kmask2: u32 = 0x0f0f0f0f;
313+
var scale_vals: array<u32, 4>;
314+
for (var i: u32 = 0; i < 4; i++) {
315+
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
316+
}
317+
var tmp: u32 = scale_vals[2];
318+
scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
319+
scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
320+
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
321+
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
322+
323+
// convert half-precision floats to packed 32-bit integers
324+
var hmask_vals: array<u32, 8>;
325+
for (var i: u32 = 0; i < 8; i++) {
326+
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
327+
}
328+
var qs_vals: array<u32, 16>;
329+
for (var i: u32 = 0; i < 16; i++) {
330+
qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1]));
331+
}
332+
333+
var sum = 0.0;
334+
var src1_i = src1_idx_base + offset * 256;
335+
var is: u32 = 0;
336+
var m: u32 = 1;
337+
// 2 halves of the block (128 elements each)
338+
for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) {
339+
// 4 groups (each group has 2 blocks of 16 elements)
340+
for (var shift: u32 = 0; shift < 8; shift += 2) {
341+
// 2 blocks
342+
for (var k: u32 = 0; k < 32; k += 16) {
343+
let sc = (scale_vals[is / 4] >> ((is % 4) * 8)) & 0xFF;
344+
is++;
345+
let dl = d * (f32(sc) - 32.0);
346+
for (var l: u32 = 0u; l < 16u; l++) {
347+
let q_idx = q_b_idx + k + l;
348+
let hm_idx = k + l;
349+
let q_byte = (qs_vals[q_idx / 4] >> ((q_idx % 4) * 8)) & 0xFF;
350+
let hmask_byte = (hmask_vals[hm_idx / 4] >> ((hm_idx % 4) * 8)) & 0xFF;
351+
let hm = select(4.0, 0.0, (hmask_byte & m) != 0);
352+
let qs_val = (q_byte >> shift) & 3;
353+
sum += ((f32(qs_val) - hm) * dl) * src1[src1_i];
354+
src1_i++;
355+
}
356+
}
357+
m <<= 1;
358+
}
359+
}
360+
return sum;
361+
}
362+
363+
#enddecl(Q3_K)
364+
303365
#end(DECLS)
304366

305367
#define(SHADER)

0 commit comments

Comments
 (0)