Skip to content

Commit d76e562

Browse files
committed
Add rest of k-quants
1 parent de4da87 commit d76e562

File tree

2 files changed

+217
-3
lines changed

2 files changed

+217
-3
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct webgpu_context_struct {
129129
webgpu_buf_pool set_rows_error_buf_pool;
130130

131131
wgpu::ComputePipeline memset_pipeline;
132-
wgpu::ComputePipeline mul_mat_pipeline[12][2];
132+
wgpu::ComputePipeline mul_mat_pipeline[15][2];
133133
wgpu::ComputePipeline set_rows_pipeline;
134134
wgpu::ComputePipeline cpy_pipeline;
135135

@@ -910,7 +910,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
910910
}
911911

912912
static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
913-
webgpu_pipeline_info pipeline_infos[10] = {
913+
webgpu_pipeline_info pipeline_infos[13] = {
914914
{ .name = "mul_mat_f32_f32",
915915
.shader_code = wgsl_mul_mat_f32_f32,
916916
.src0_type = GGML_TYPE_F32,
@@ -950,6 +950,18 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
950950
{ .name = "mul_mat_q3_k_f32",
951951
.shader_code = wgsl_mul_mat_q3_k_f32,
952952
.src0_type = GGML_TYPE_Q3_K,
953+
.src1_type = GGML_TYPE_F32 },
954+
{ .name = "mul_mat_q4_k_f32",
955+
.shader_code = wgsl_mul_mat_q4_k_f32,
956+
.src0_type = GGML_TYPE_Q4_K,
957+
.src1_type = GGML_TYPE_F32 },
958+
{ .name = "mul_mat_q5_k_f32",
959+
.shader_code = wgsl_mul_mat_q5_k_f32,
960+
.src0_type = GGML_TYPE_Q5_K,
961+
.src1_type = GGML_TYPE_F32 },
962+
{ .name = "mul_mat_q6_k_f32",
963+
.shader_code = wgsl_mul_mat_q6_k_f32,
964+
.src0_type = GGML_TYPE_Q6_K,
953965
.src1_type = GGML_TYPE_F32 }
954966
};
955967

@@ -1049,6 +1061,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
10491061
case GGML_TYPE_Q8_0:
10501062
case GGML_TYPE_Q2_K:
10511063
case GGML_TYPE_Q3_K:
1064+
case GGML_TYPE_Q4_K:
1065+
case GGML_TYPE_Q5_K:
1066+
case GGML_TYPE_Q6_K:
10521067
return true;
10531068
default:
10541069
return false;

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl

Lines changed: 200 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,30 @@
8080
"BLOCK_SIZE": 256
8181
},
8282
"DECLS": "Q3_K"
83+
},
84+
{
85+
"REPLS": {
86+
"SRC0_TYPE": "q4_k",
87+
"SRC1_TYPE": "f32",
88+
"BLOCK_SIZE": 256
89+
},
90+
"DECLS": "Q4_K"
91+
},
92+
{
93+
"REPLS": {
94+
"SRC0_TYPE": "q5_k",
95+
"SRC1_TYPE": "f32",
96+
"BLOCK_SIZE": 256
97+
},
98+
"DECLS": "Q5_K"
99+
},
100+
{
101+
"REPLS": {
102+
"SRC0_TYPE": "q6_k",
103+
"SRC1_TYPE": "f32",
104+
"BLOCK_SIZE": 256
105+
},
106+
"DECLS": "Q6_K"
83107
}
84108
]
85109

@@ -320,7 +344,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
320344
scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4);
321345
scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
322346

323-
// convert half-precision floats to packed 32-bit integers
347+
// convert arrays of f16 -> u32
324348
var hmask_vals: array<u32, 8>;
325349
for (var i: u32 = 0; i < 8; i++) {
326350
hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1]));
@@ -362,6 +386,181 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
362386

363387
#enddecl(Q3_K)
364388

389+
#decl(Q4_K)
390+
// 8 blocks of 32 elements each
391+
// Storage layout of one Q4_K super-block as consumed by this shader variant.
struct q4_k {
392+
// super-block scale; multiply_add converts it to f32 and multiplies it by each sub-block scale
d: f16,
393+
// super-block min factor; multiply_add converts it to f32 and multiplies it by each sub-block min
dmin: f16,
394+
// 12 packed bytes holding the per-sub-block scale/min values decoded by get_scale_min
scales: array<u32, 3>,
395+
// 128 bytes of quantized data; multiply_add reads each byte as two 4-bit values (low and high nibble)
qs: array<u32, 32>
396+
};
397+
398+
// Unpacks the (scale, min) pair for sub-block `is` from the 12 packed scale bytes.
// Bytes are addressed as byte k = (scales[k / 4] >> ((k % 4) * 8)) & 0xFF.
// Returns vec2(scale, min), both converted to f32.
// The multiply_add in this decl calls it with is = 0..7.
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
399+
// Sub-blocks 0..3: scale is the low 6 bits of byte `is`, min is the low 6 bits of byte `is + 4`.
if (is < 4) {
400+
let sc_byte = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
401+
// (is + 4) % 4 == is % 4, so the same shift selects byte `is + 4`
let min_byte = (scales[(is + 4) / 4] >> ((is % 4) * 8)) & 0xFF;
402+
return vec2(f32(sc_byte & 63), f32(min_byte & 63));
403+
// Sub-blocks 4 and up: the 6-bit values are split across three bytes.
} else {
404+
// byte `is + 4`: low nibble = low 4 bits of the scale, high nibble = low 4 bits of the min
let sc_min_lo = (scales[(is + 4) / 4] >> (((is + 4) % 4) * 8)) & 0xFF;
405+
// byte `is - 4`: its top 2 bits are the high 2 bits of the scale
let sc_hi = (scales[(is - 4) / 4] >> (((is - 4) % 4) * 8)) & 0xFF;
406+
// byte `is`: its top 2 bits are the high 2 bits of the min
let min_hi = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
407+
let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);
408+
let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);
409+
return vec2(f32(sc), f32(m));
410+
}
411+
}
412+
413+
// Dequantizes the q4_k block at src0[src0_idx_base + offset] and accumulates its dot
// product with 256 consecutive src1 values starting at src1_idx_base + offset * 256.
// Returns the partial sum for this block.
// NOTE(review): src0 / src1 are declared outside this decl — presumably the shader's
// input bindings; confirm against the SHADER section.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
414+
let block = src0[src0_idx_base + offset];
415+
let d = f32(block.d);
416+
let m = f32(block.dmin);
417+
var sum = 0.0;
418+
// each block covers 256 src1 elements
var src1_i = src1_idx_base + offset * 256;
419+
// running sub-block index passed to get_scale_min (0..7 over the whole block)
var is: u32 = 0;
420+
// 2 blocks each iteration
// (each group of 32 qs bytes yields two 32-element sub-blocks: low nibbles, then high nibbles)
421+
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
422+
// shift = 0 selects the low nibble of each byte, shift = 4 the high nibble
for (var shift: u32 = 0; shift < 8; shift += 4) {
423+
let scale_min = get_scale_min(is, block.scales);
424+
is++;
425+
// effective scale and min offset for this sub-block
let dl = d * scale_min.x;
426+
let ml = m * scale_min.y;
427+
for (var l: u32 = 0; l < 32; l++) {
428+
let q_idx = q_b_idx + l;
429+
// extract byte q_idx from the u32-packed qs array
let q_byte = (block.qs[q_idx / 4] >> ((q_idx % 4) * 8)) & 0xFF;
430+
let qs_val = (q_byte >> shift) & 0xF;
431+
// dequantized value is q * dl - ml
sum += (f32(qs_val) * dl - ml) * src1[src1_i];
432+
src1_i++;
433+
}
434+
}
435+
}
436+
return sum;
437+
}
438+
439+
#enddecl(Q4_K)
440+
441+
#decl(Q5_K)
442+
// 8 blocks of 32 elements each
443+
// Storage layout of one Q5_K super-block as consumed by this shader variant.
struct q5_k {
444+
// super-block scale; multiply_add converts it to f32 and multiplies it by each sub-block scale
d: f16,
445+
// super-block min factor; multiply_add converts it to f32 and multiplies it by each sub-block min
dmin: f16,
446+
// 12 packed bytes holding the per-sub-block scale/min values decoded by get_scale_min
scales: array<u32, 3>,
447+
// 32 bytes of high bits: byte l carries one bit per sub-block for element l (tested with a moving mask in multiply_add)
qh: array<u32, 8>,
448+
// 128 bytes of low 4-bit values, two per byte (low and high nibble)
qs: array<u32, 32>
449+
};
450+
451+
// Unpacks the (scale, min) pair for sub-block `is` from the 12 packed scale bytes.
// Bytes are addressed as byte k = (scales[k / 4] >> ((k % 4) * 8)) & 0xFF.
// Returns vec2(scale, min), both converted to f32.
// The multiply_add in this decl calls it with is = 0..7.
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
452+
// Sub-blocks 0..3: scale is the low 6 bits of byte `is`, min is the low 6 bits of byte `is + 4`.
if (is < 4) {
453+
let sc_byte = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
454+
// (is + 4) % 4 == is % 4, so the same shift selects byte `is + 4`
let min_byte = (scales[(is + 4) / 4] >> ((is % 4) * 8)) & 0xFF;
455+
return vec2(f32(sc_byte & 63), f32(min_byte & 63));
456+
// Sub-blocks 4 and up: the 6-bit values are split across three bytes.
} else {
457+
// byte `is + 4`: low nibble = low 4 bits of the scale, high nibble = low 4 bits of the min
let sc_min_lo = (scales[(is + 4) / 4] >> (((is + 4) % 4) * 8)) & 0xFF;
458+
// byte `is - 4`: its top 2 bits are the high 2 bits of the scale
let sc_hi = (scales[(is - 4) / 4] >> (((is - 4) % 4) * 8)) & 0xFF;
459+
// byte `is`: its top 2 bits are the high 2 bits of the min
let min_hi = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
460+
let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);
461+
let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);
462+
return vec2(f32(sc), f32(m));
463+
}
464+
}
465+
466+
// Dequantizes the q5_k block at src0[src0_idx_base + offset] and accumulates its dot
// product with 256 consecutive src1 values starting at src1_idx_base + offset * 256.
// Returns the partial sum for this block.
// NOTE(review): src0 / src1 are declared outside this decl — presumably the shader's
// input bindings; confirm against the SHADER section.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
467+
let block = src0[src0_idx_base + offset];
468+
let d = f32(block.d);
469+
let m = f32(block.dmin);
470+
var sum = 0.0;
471+
// each block covers 256 src1 elements
var src1_i = src1_idx_base + offset * 256;
472+
// running sub-block index passed to get_scale_min (0..7 over the whole block)
var is: u32 = 0;
473+
// bit mask selecting the current sub-block's high bit inside each qh byte; shifted left once per sub-block
var u: u32 = 1;
474+
// 2 blocks each iteration
// (each group of 32 qs bytes yields two 32-element sub-blocks: low nibbles, then high nibbles)
475+
for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
476+
// shift = 0 selects the low nibble of each byte, shift = 4 the high nibble
for (var shift: u32 = 0; shift < 8; shift += 4) {
477+
let scale_min = get_scale_min(is, block.scales);
478+
is++;
479+
// effective scale and min offset for this sub-block
let dl = d * scale_min.x;
480+
let ml = m * scale_min.y;
481+
for (var l: u32 = 0; l < 32; l++) {
482+
let q_idx = q_b_idx + l;
483+
// extract byte q_idx from the u32-packed qs array
let q_byte = (block.qs[q_idx / 4] >> ((q_idx % 4) * 8)) & 0xFF;
484+
// qh is indexed by l only: the same 32 bytes serve every sub-block, distinguished by mask u
let qh_byte = (block.qh[l / 4] >> ((l % 4) * 8)) & 0xFF;
485+
let qs_val = (q_byte >> shift) & 0xF;
486+
// a set high bit contributes 16 (the fifth bit of the quantized value)
let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
487+
// dequantized value is q * dl - ml
sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
488+
src1_i++;
489+
}
490+
// advance to the next sub-block's bit in qh
u <<= 1;
491+
}
492+
}
493+
return sum;
494+
}
495+
496+
#enddecl(Q5_K)
497+
498+
#decl(Q6_K)
499+
// 16 blocks of 16 elements each
500+
// Storage layout of one Q6_K super-block as consumed by this shader variant.
// The ql / qh / scales arrays are declared as f16 but multiply_add treats them as raw
// bytes: it bitcasts pairs of f16 into u32 words before extracting individual bytes.
// NOTE(review): presumably f16 is used so the struct keeps 2-byte granularity — confirm.
struct q6_k {
501+
// 128 bytes: low 4 bits of the quantized values, two per byte
ql: array<f16, 64>,
502+
// 64 bytes: high 2 bits of the quantized values, four per byte
qh: array<f16, 32>,
503+
// 16 bytes: one signed 8-bit scale per 16-element sub-block (sign-extended in multiply_add)
scales: array<f16, 8>,
504+
// super-block scale; converted to f32 in multiply_add
d: f16
505+
};
506+
507+
// Dequantizes the q6_k block at src0[src0_idx_base + offset] and accumulates its dot
// product with 256 consecutive src1 values starting at src1_idx_base + offset * 256.
// Returns the partial sum for this block.
// NOTE(review): src0 / src1 are declared outside this decl — presumably the shader's
// input bindings; confirm against the SHADER section.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
508+
let block = src0[src0_idx_base + offset];
509+
let d = f32(block.d);
510+
511+
// convert arrays of f16 -> u32
// (reinterprets each pair of f16 as one u32 so bytes can be extracted with shifts and masks)
512+
var ql_vals: array<u32, 32>;
513+
for (var i: u32 = 0; i < 32; i++) {
514+
ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
515+
}
516+
var qh_vals: array<u32, 16>;
517+
for (var i: u32 = 0; i < 16; i++) {
518+
qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
519+
}
520+
var scale_vals: array<u32, 4>;
521+
for (var i: u32 = 0; i < 4; i++) {
522+
scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
523+
}
524+
525+
var sum = 0.0;
526+
var src1_i = src1_idx_base + offset * 256;
527+
// byte offsets into qh and scales for the current half of the block
var qh_b_idx: u32 = 0;
528+
var sc_b_idx: u32 = 0;
// two iterations, each consuming 64 ql bytes, 32 qh bytes, 8 scales and producing 128 elements
529+
for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
530+
// each l produces four elements, at offsets l, l + 32, l + 64 and l + 96 within the half
for (var l: u32 = 0; l < 32; l++) {
531+
// ql byte l: low nibble feeds q1, high nibble feeds q3
let ql13_b = (ql_vals[(ql_b_idx + l) / 4] >> (((ql_b_idx + l) % 4) * 8)) & 0xFF;
532+
// ql byte l + 32: low nibble feeds q2, high nibble feeds q4
let ql24_b = (ql_vals[(ql_b_idx + l + 32) / 4] >> (((ql_b_idx + l + 32) % 4) * 8)) & 0xFF;
533+
// qh byte l: four 2-bit fields, one per output element
let qh_b = ((qh_vals[(qh_b_idx + l) / 4] >> (((qh_b_idx + l) % 4) * 8))) & 0xFF;
534+
535+
// combine 4 low bits with 2 high bits into a 6-bit value, then recenter by subtracting 32
let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
536+
let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
537+
let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
538+
let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;
539+
540+
// one scale per 16 elements: l < 16 uses the first scale of each pair, l >= 16 the second
let is = l/16;
541+
let is1 = sc_b_idx + is;
542+
// shift the byte to the top of the word and arithmetic-shift back to sign-extend it to i32
let sc1 = bitcast<i32>(((scale_vals[is1 / 4] >> ((is1 % 4) * 8)) & 0xFF) << 24) >> 24;
543+
let is2 = sc_b_idx + is + 2;
544+
let sc2 = bitcast<i32>(((scale_vals[is2 / 4] >> ((is2 % 4) * 8)) & 0xFF) << 24) >> 24;
545+
let is3 = sc_b_idx + is + 4;
546+
let sc3 = bitcast<i32>(((scale_vals[is3 / 4] >> ((is3 % 4) * 8)) & 0xFF) << 24) >> 24;
547+
let is4 = sc_b_idx + is + 6;
548+
let sc4 = bitcast<i32>(((scale_vals[is4 / 4] >> ((is4 % 4) * 8)) & 0xFF) << 24) >> 24;
549+
550+
// dequantized value is d * scale * q; accumulate against the matching src1 element
sum += d * f32(sc1) * q1 * src1[src1_i + l];
551+
sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
552+
sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
553+
sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
554+
}
555+
// advance to the second half of the block
src1_i += 128;
556+
qh_b_idx += 32;
557+
sc_b_idx += 8;
558+
}
559+
return sum;
560+
}
561+
562+
#enddecl(Q6_K)
563+
365564
#end(DECLS)
366565

367566
#define(SHADER)

0 commit comments

Comments
 (0)