|
80 | 80 | "BLOCK_SIZE": 256 |
81 | 81 | }, |
82 | 82 | "DECLS": "Q3_K" |
| 83 | + }, |
| 84 | + { |
| 85 | + "REPLS": { |
| 86 | + "SRC0_TYPE": "q4_k", |
| 87 | + "SRC1_TYPE": "f32", |
| 88 | + "BLOCK_SIZE": 256 |
| 89 | + }, |
| 90 | + "DECLS": "Q4_K" |
| 91 | + }, |
| 92 | + { |
| 93 | + "REPLS": { |
| 94 | + "SRC0_TYPE": "q5_k", |
| 95 | + "SRC1_TYPE": "f32", |
| 96 | + "BLOCK_SIZE": 256 |
| 97 | + }, |
| 98 | + "DECLS": "Q5_K" |
| 99 | + }, |
| 100 | + { |
| 101 | + "REPLS": { |
| 102 | + "SRC0_TYPE": "q6_k", |
| 103 | + "SRC1_TYPE": "f32", |
| 104 | + "BLOCK_SIZE": 256 |
| 105 | + }, |
| 106 | + "DECLS": "Q6_K" |
83 | 107 | } |
84 | 108 | ] |
85 | 109 |
|
@@ -320,7 +344,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { |
320 | 344 | scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); |
321 | 345 | scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); |
322 | 346 |
|
323 | | - // convert half-precision floats to packed 32-bit integers |
| 347 | + // convert arrays of f16 -> u32 |
324 | 348 | var hmask_vals: array<u32, 8>; |
325 | 349 | for (var i: u32 = 0; i < 8; i++) { |
326 | 350 | hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); |
@@ -362,6 +386,181 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { |
362 | 386 |
|
363 | 387 | #enddecl(Q3_K) |
364 | 388 |
|
#decl(Q4_K)
// GGML Q4_K super-block: 256 elements quantized to 4 bits, organized as
// 8 sub-blocks of 32 elements. Each sub-block has a 6-bit scale and a
// 6-bit min packed into the 12-byte `scales` field. WGSL has no u8 type,
// so byte data is stored in u32 words and extracted with shifts/masks.
// 8 blocks of 32 elements each
struct q4_k {
    d: f16,                 // super-block scale applied to the 6-bit sub-block scales
    dmin: f16,              // super-block scale applied to the 6-bit sub-block mins
    scales: array<u32, 3>,  // 12 bytes: packed 6-bit (scale, min) pairs for 8 sub-blocks
    qs: array<u32, 32>      // 128 bytes: 256 4-bit quants, two per byte (low/high nibble)
};
// Decode the (scale, min) pair for sub-block `is` (0..7) from the 12-byte
// packed K-quant scale layout. As the extraction below shows:
//   - for is < 4: scale = low 6 bits of byte `is`, min = low 6 bits of byte `is + 4`
//   - for is >= 4: byte `is + 4` holds the low nibbles (scale in bits 0-3,
//     min in bits 4-7); the top 2 bits of bytes `is - 4` / `is` supply the
//     high 2 bits of scale / min respectively.
// Returns vec2(scale, min) as f32.
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
    if (is < 4) {
        // byte `is` -> scale, byte `is + 4` -> min (low 6 bits each)
        let sc_byte = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
        let min_byte = (scales[(is + 4) / 4] >> ((is % 4) * 8)) & 0xFF;
        return vec2(f32(sc_byte & 63), f32(min_byte & 63));
    } else {
        // byte `is + 4` packs both low nibbles; high 2 bits come from the
        // top of bytes `is - 4` (scale) and `is` (min)
        let sc_min_lo = (scales[(is + 4) / 4] >> (((is + 4) % 4) * 8)) & 0xFF;
        let sc_hi = (scales[(is - 4) / 4] >> (((is - 4) % 4) * 8)) & 0xFF;
        let min_hi = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
        let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);
        let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);
        return vec2(f32(sc), f32(m));
    }
}
// Dot product of one Q4_K block (256 dequantized elements) with 256
// consecutive f32 values from src1.
//   src0_idx_base / src1_idx_base: base indices into src0 (blocks) and src1 (f32)
//   offset: block offset from the base; src1 advances 256 floats per block
// Dequantization: value = q * (d * scale) - (dmin * min).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    // 2 blocks each iteration: the low and high nibbles of one 32-byte run
    // of qs belong to two consecutive 32-element sub-blocks (shift 0 / 4)
    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
        for (var shift: u32 = 0; shift < 8; shift += 4) {
            let scale_min = get_scale_min(is, block.scales);
            is++;
            let dl = d * scale_min.x;  // effective per-sub-block scale
            let ml = m * scale_min.y;  // effective per-sub-block min
            for (var l: u32 = 0; l < 32; l++) {
                let q_idx = q_b_idx + l;
                // pick byte q_idx out of the u32-packed qs, then the nibble
                let q_byte = (block.qs[q_idx / 4] >> ((q_idx % 4) * 8)) & 0xFF;
                let qs_val = (q_byte >> shift) & 0xF;
                sum += (f32(qs_val) * dl - ml) * src1[src1_i];
                src1_i++;
            }
        }
    }
    return sum;
}

#enddecl(Q4_K)
| 440 | + |
#decl(Q5_K)
// GGML Q5_K super-block: 256 elements at 5 bits each — a 4-bit low part in
// `qs` plus one high bit per element in `qh`. Scales/mins use the same
// packed 6-bit layout as Q4_K. Byte data is stored in u32 words because
// WGSL has no u8 type.
// 8 blocks of 32 elements each
struct q5_k {
    d: f16,                 // super-block scale applied to the 6-bit sub-block scales
    dmin: f16,              // super-block scale applied to the 6-bit sub-block mins
    scales: array<u32, 3>,  // 12 bytes: packed 6-bit (scale, min) pairs for 8 sub-blocks
    qh: array<u32, 8>,      // 32 bytes: the 5th (high) bit of each of the 256 quants
    qs: array<u32, 32>      // 128 bytes: low 4 bits of the quants, two per byte
};
// Decode the (scale, min) pair for sub-block `is` (0..7) from the 12-byte
// packed K-quant scale layout (identical to the Q4_K variant; duplicated
// here so each #decl section stays self-contained). As the extraction
// below shows:
//   - for is < 4: scale = low 6 bits of byte `is`, min = low 6 bits of byte `is + 4`
//   - for is >= 4: byte `is + 4` holds the low nibbles; the top 2 bits of
//     bytes `is - 4` / `is` supply the high 2 bits of scale / min.
// Returns vec2(scale, min) as f32.
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
    if (is < 4) {
        // byte `is` -> scale, byte `is + 4` -> min (low 6 bits each)
        let sc_byte = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
        let min_byte = (scales[(is + 4) / 4] >> ((is % 4) * 8)) & 0xFF;
        return vec2(f32(sc_byte & 63), f32(min_byte & 63));
    } else {
        // byte `is + 4` packs both low nibbles; high 2 bits come from the
        // top of bytes `is - 4` (scale) and `is` (min)
        let sc_min_lo = (scales[(is + 4) / 4] >> (((is + 4) % 4) * 8)) & 0xFF;
        let sc_hi = (scales[(is - 4) / 4] >> (((is - 4) % 4) * 8)) & 0xFF;
        let min_hi = (scales[is / 4] >> ((is % 4) * 8)) & 0xFF;
        let sc = (sc_min_lo & 0xF) | ((sc_hi >> 6) << 4);
        let m = (sc_min_lo >> 4) | ((min_hi >> 6) << 4);
        return vec2(f32(sc), f32(m));
    }
}
// Dot product of one Q5_K block (256 dequantized elements) with 256
// consecutive f32 values from src1.
//   src0_idx_base / src1_idx_base: base indices into src0 (blocks) and src1 (f32)
//   offset: block offset from the base; src1 advances 256 floats per block
// Dequantization: value = (q_low + 16 * q_high_bit) * (d * scale) - (dmin * min).
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);
    let m = f32(block.dmin);
    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var is: u32 = 0;
    // one-hot mask selecting which bit of each qh byte is the current
    // sub-block's 5th bit; shifted left once per sub-block (8 total)
    var u: u32 = 1;
    // 2 blocks each iteration: low and high nibbles of a 32-byte qs run
    // belong to two consecutive 32-element sub-blocks (shift 0 / 4)
    for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) {
        for (var shift: u32 = 0; shift < 8; shift += 4) {
            let scale_min = get_scale_min(is, block.scales);
            is++;
            let dl = d * scale_min.x;  // effective per-sub-block scale
            let ml = m * scale_min.y;  // effective per-sub-block min
            for (var l: u32 = 0; l < 32; l++) {
                let q_idx = q_b_idx + l;
                // pick byte q_idx (low nibbles) and byte l (high bits)
                let q_byte = (block.qs[q_idx / 4] >> ((q_idx % 4) * 8)) & 0xFF;
                let qh_byte = (block.qh[l / 4] >> ((l % 4) * 8)) & 0xFF;
                let qs_val = (q_byte >> shift) & 0xF;
                // the 5th bit contributes +16 when set
                let qh_val = select(0.0, 16.0, (qh_byte & u) != 0);
                sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i];
                src1_i++;
            }
            u <<= 1;
        }
    }
    return sum;
}

#enddecl(Q5_K)
| 497 | + |
#decl(Q6_K)
// GGML Q6_K super-block: 256 elements at 6 bits each (4 low bits in `ql`,
// 2 high bits in `qh`), with 16 signed 8-bit scales — one per 16-element
// sub-block — and a single f16 super-block scale `d`.
// NOTE(review): the byte fields are declared as f16 arrays purely for
// storage; multiply_add bitcasts f16 pairs back to u32 and extracts raw
// bytes from them (WGSL has no u8 type).
// 16 blocks of 16 elements each
struct q6_k {
    ql: array<f16, 64>,     // 128 bytes: low 4 bits of 256 quants, two per byte
    qh: array<f16, 32>,     // 64 bytes: high 2 bits of 256 quants, four per byte
    scales: array<f16, 8>,  // 16 bytes: 16 signed 8-bit sub-block scales
    d: f16                  // super-block scale
};
// Dot product of one Q6_K block (256 dequantized elements) with 256
// consecutive f32 values from src1.
//   src0_idx_base / src1_idx_base: base indices into src0 (blocks) and src1 (f32)
//   offset: block offset from the base; src1 advances 256 floats per block
// Dequantization: value = d * scale[sub-block] * (q - 32), where q is
// rebuilt from a ql nibble plus 2 bits of qh.
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
    let block = src0[src0_idx_base + offset];
    let d = f32(block.d);

    // convert arrays of f16 -> u32: recover the raw packed bytes from the
    // f16-typed storage so they can be byte-addressed below
    var ql_vals: array<u32, 32>;
    for (var i: u32 = 0; i < 32; i++) {
        ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1]));
    }
    var qh_vals: array<u32, 16>;
    for (var i: u32 = 0; i < 16; i++) {
        qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1]));
    }
    var scale_vals: array<u32, 4>;
    for (var i: u32 = 0; i < 4; i++) {
        scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1]));
    }

    var sum = 0.0;
    var src1_i = src1_idx_base + offset * 256;
    var qh_b_idx: u32 = 0;
    var sc_b_idx: u32 = 0;
    // each outer iteration covers 128 elements: 4 interleaved groups of 32
    // (q1..q4 below), sharing 64 ql bytes and 32 qh bytes
    for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) {
        for (var l: u32 = 0; l < 32; l++) {
            // two ql bytes give the low nibbles of 4 quants; one qh byte
            // holds their four 2-bit high parts
            let ql13_b = (ql_vals[(ql_b_idx + l) / 4] >> (((ql_b_idx + l) % 4) * 8)) & 0xFF;
            let ql24_b = (ql_vals[(ql_b_idx + l + 32) / 4] >> (((ql_b_idx + l + 32) % 4) * 8)) & 0xFF;
            let qh_b = ((qh_vals[(qh_b_idx + l) / 4] >> (((qh_b_idx + l) % 4) * 8))) & 0xFF;

            // 6-bit quants are stored with a +32 bias; remove it here
            let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0;
            let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0;
            let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0;
            let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0;

            let is = l/16;  // 16-element sub-block index within this group
            // extract byte isN and sign-extend it (<< 24 then arithmetic >> 24)
            let is1 = sc_b_idx + is;
            let sc1 = bitcast<i32>(((scale_vals[is1 / 4] >> ((is1 % 4) * 8)) & 0xFF) << 24) >> 24;
            let is2 = sc_b_idx + is + 2;
            let sc2 = bitcast<i32>(((scale_vals[is2 / 4] >> ((is2 % 4) * 8)) & 0xFF) << 24) >> 24;
            let is3 = sc_b_idx + is + 4;
            let sc3 = bitcast<i32>(((scale_vals[is3 / 4] >> ((is3 % 4) * 8)) & 0xFF) << 24) >> 24;
            let is4 = sc_b_idx + is + 6;
            let sc4 = bitcast<i32>(((scale_vals[is4 / 4] >> ((is4 % 4) * 8)) & 0xFF) << 24) >> 24;

            // the 4 groups map to src1 strides of 0/32/64/96 within the 128
            sum += d * f32(sc1) * q1 * src1[src1_i + l];
            sum += d * f32(sc2) * q2 * src1[src1_i + l + 32];
            sum += d * f32(sc3) * q3 * src1[src1_i + l + 64];
            sum += d * f32(sc4) * q4 * src1[src1_i + l + 96];
        }
        src1_i += 128;
        qh_b_idx += 32;
        sc_b_idx += 8;
    }
    return sum;
}

#enddecl(Q6_K)
| 563 | + |
365 | 564 | #end(DECLS) |
366 | 565 |
|
367 | 566 | #define(SHADER) |
|
0 commit comments