Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,13 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32)
const b_weight_offset : u32 = 0;
let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col);
#endif
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
let block_idx = kidx_v/(block_size/16);
#if has_zero_points
let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col);
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value, zero);
#else
tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value);
#endif
let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size);
scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx);
}
Expand All @@ -150,6 +155,11 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32)
$MAIN {
#if n_bits == 2
LoadDequantizationTable(local_idx);
#if has_zero_points
LoadDequantizationTable(local_idx + 256);
LoadDequantizationTable(local_idx + 512);
LoadDequantizationTable(local_idx + 768);
#endif
workgroupBarrier();
#endif
// During the load phase we use all 256 threads to load 64 rows of A/B.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,287 @@

#if n_bits == 2
alias mul_precision = output_element_t;
#if has_zero_points
const lut_size = 1024;
var<workgroup> shm_dequantization_table : array<u32, lut_size>;
// 1024-entry LUT: 4 sections of 256 entries, one per zero_point value (0-3).
// Index as: zero * 256 + byte_value
const q2_dequantization_table = array<u32, lut_size>(
// zero_point = 0: entries 0-255
0x00000000, 0x00000001, 0x00000002, 0x00000003,
0x00000100, 0x00000101, 0x00000102, 0x00000103,
0x00000200, 0x00000201, 0x00000202, 0x00000203,
0x00000300, 0x00000301, 0x00000302, 0x00000303,
0x00010000, 0x00010001, 0x00010002, 0x00010003,
0x00010100, 0x00010101, 0x00010102, 0x00010103,
0x00010200, 0x00010201, 0x00010202, 0x00010203,
0x00010300, 0x00010301, 0x00010302, 0x00010303,
0x00020000, 0x00020001, 0x00020002, 0x00020003,
0x00020100, 0x00020101, 0x00020102, 0x00020103,
0x00020200, 0x00020201, 0x00020202, 0x00020203,
0x00020300, 0x00020301, 0x00020302, 0x00020303,
0x00030000, 0x00030001, 0x00030002, 0x00030003,
0x00030100, 0x00030101, 0x00030102, 0x00030103,
0x00030200, 0x00030201, 0x00030202, 0x00030203,
0x00030300, 0x00030301, 0x00030302, 0x00030303,
0x01000000, 0x01000001, 0x01000002, 0x01000003,
0x01000100, 0x01000101, 0x01000102, 0x01000103,
0x01000200, 0x01000201, 0x01000202, 0x01000203,
0x01000300, 0x01000301, 0x01000302, 0x01000303,
0x01010000, 0x01010001, 0x01010002, 0x01010003,
0x01010100, 0x01010101, 0x01010102, 0x01010103,
0x01010200, 0x01010201, 0x01010202, 0x01010203,
0x01010300, 0x01010301, 0x01010302, 0x01010303,
0x01020000, 0x01020001, 0x01020002, 0x01020003,
0x01020100, 0x01020101, 0x01020102, 0x01020103,
0x01020200, 0x01020201, 0x01020202, 0x01020203,
0x01020300, 0x01020301, 0x01020302, 0x01020303,
0x01030000, 0x01030001, 0x01030002, 0x01030003,
0x01030100, 0x01030101, 0x01030102, 0x01030103,
0x01030200, 0x01030201, 0x01030202, 0x01030203,
0x01030300, 0x01030301, 0x01030302, 0x01030303,
0x02000000, 0x02000001, 0x02000002, 0x02000003,
0x02000100, 0x02000101, 0x02000102, 0x02000103,
0x02000200, 0x02000201, 0x02000202, 0x02000203,
0x02000300, 0x02000301, 0x02000302, 0x02000303,
0x02010000, 0x02010001, 0x02010002, 0x02010003,
0x02010100, 0x02010101, 0x02010102, 0x02010103,
0x02010200, 0x02010201, 0x02010202, 0x02010203,
0x02010300, 0x02010301, 0x02010302, 0x02010303,
0x02020000, 0x02020001, 0x02020002, 0x02020003,
0x02020100, 0x02020101, 0x02020102, 0x02020103,
0x02020200, 0x02020201, 0x02020202, 0x02020203,
0x02020300, 0x02020301, 0x02020302, 0x02020303,
0x02030000, 0x02030001, 0x02030002, 0x02030003,
0x02030100, 0x02030101, 0x02030102, 0x02030103,
0x02030200, 0x02030201, 0x02030202, 0x02030203,
0x02030300, 0x02030301, 0x02030302, 0x02030303,
0x03000000, 0x03000001, 0x03000002, 0x03000003,
0x03000100, 0x03000101, 0x03000102, 0x03000103,
0x03000200, 0x03000201, 0x03000202, 0x03000203,
0x03000300, 0x03000301, 0x03000302, 0x03000303,
0x03010000, 0x03010001, 0x03010002, 0x03010003,
0x03010100, 0x03010101, 0x03010102, 0x03010103,
0x03010200, 0x03010201, 0x03010202, 0x03010203,
0x03010300, 0x03010301, 0x03010302, 0x03010303,
0x03020000, 0x03020001, 0x03020002, 0x03020003,
0x03020100, 0x03020101, 0x03020102, 0x03020103,
0x03020200, 0x03020201, 0x03020202, 0x03020203,
0x03020300, 0x03020301, 0x03020302, 0x03020303,
0x03030000, 0x03030001, 0x03030002, 0x03030003,
0x03030100, 0x03030101, 0x03030102, 0x03030103,
0x03030200, 0x03030201, 0x03030202, 0x03030203,
0x03030300, 0x03030301, 0x03030302, 0x03030303,
// zero_point = 1: entries 256-511
0xFFFFFFFF, 0xFFFFFF00, 0xFFFFFF01, 0xFFFFFF02,
0xFFFF00FF, 0xFFFF0000, 0xFFFF0001, 0xFFFF0002,
0xFFFF01FF, 0xFFFF0100, 0xFFFF0101, 0xFFFF0102,
0xFFFF02FF, 0xFFFF0200, 0xFFFF0201, 0xFFFF0202,
0xFF00FFFF, 0xFF00FF00, 0xFF00FF01, 0xFF00FF02,
0xFF0000FF, 0xFF000000, 0xFF000001, 0xFF000002,
0xFF0001FF, 0xFF000100, 0xFF000101, 0xFF000102,
0xFF0002FF, 0xFF000200, 0xFF000201, 0xFF000202,
0xFF01FFFF, 0xFF01FF00, 0xFF01FF01, 0xFF01FF02,
0xFF0100FF, 0xFF010000, 0xFF010001, 0xFF010002,
0xFF0101FF, 0xFF010100, 0xFF010101, 0xFF010102,
0xFF0102FF, 0xFF010200, 0xFF010201, 0xFF010202,
0xFF02FFFF, 0xFF02FF00, 0xFF02FF01, 0xFF02FF02,
0xFF0200FF, 0xFF020000, 0xFF020001, 0xFF020002,
0xFF0201FF, 0xFF020100, 0xFF020101, 0xFF020102,
0xFF0202FF, 0xFF020200, 0xFF020201, 0xFF020202,
0x00FFFFFF, 0x00FFFF00, 0x00FFFF01, 0x00FFFF02,
0x00FF00FF, 0x00FF0000, 0x00FF0001, 0x00FF0002,
0x00FF01FF, 0x00FF0100, 0x00FF0101, 0x00FF0102,
0x00FF02FF, 0x00FF0200, 0x00FF0201, 0x00FF0202,
0x0000FFFF, 0x0000FF00, 0x0000FF01, 0x0000FF02,
0x000000FF, 0x00000000, 0x00000001, 0x00000002,
0x000001FF, 0x00000100, 0x00000101, 0x00000102,
0x000002FF, 0x00000200, 0x00000201, 0x00000202,
0x0001FFFF, 0x0001FF00, 0x0001FF01, 0x0001FF02,
0x000100FF, 0x00010000, 0x00010001, 0x00010002,
0x000101FF, 0x00010100, 0x00010101, 0x00010102,
0x000102FF, 0x00010200, 0x00010201, 0x00010202,
0x0002FFFF, 0x0002FF00, 0x0002FF01, 0x0002FF02,
0x000200FF, 0x00020000, 0x00020001, 0x00020002,
0x000201FF, 0x00020100, 0x00020101, 0x00020102,
0x000202FF, 0x00020200, 0x00020201, 0x00020202,
0x01FFFFFF, 0x01FFFF00, 0x01FFFF01, 0x01FFFF02,
0x01FF00FF, 0x01FF0000, 0x01FF0001, 0x01FF0002,
0x01FF01FF, 0x01FF0100, 0x01FF0101, 0x01FF0102,
0x01FF02FF, 0x01FF0200, 0x01FF0201, 0x01FF0202,
0x0100FFFF, 0x0100FF00, 0x0100FF01, 0x0100FF02,
0x010000FF, 0x01000000, 0x01000001, 0x01000002,
0x010001FF, 0x01000100, 0x01000101, 0x01000102,
0x010002FF, 0x01000200, 0x01000201, 0x01000202,
0x0101FFFF, 0x0101FF00, 0x0101FF01, 0x0101FF02,
0x010100FF, 0x01010000, 0x01010001, 0x01010002,
0x010101FF, 0x01010100, 0x01010101, 0x01010102,
0x010102FF, 0x01010200, 0x01010201, 0x01010202,
0x0102FFFF, 0x0102FF00, 0x0102FF01, 0x0102FF02,
0x010200FF, 0x01020000, 0x01020001, 0x01020002,
0x010201FF, 0x01020100, 0x01020101, 0x01020102,
0x010202FF, 0x01020200, 0x01020201, 0x01020202,
0x02FFFFFF, 0x02FFFF00, 0x02FFFF01, 0x02FFFF02,
0x02FF00FF, 0x02FF0000, 0x02FF0001, 0x02FF0002,
0x02FF01FF, 0x02FF0100, 0x02FF0101, 0x02FF0102,
0x02FF02FF, 0x02FF0200, 0x02FF0201, 0x02FF0202,
0x0200FFFF, 0x0200FF00, 0x0200FF01, 0x0200FF02,
0x020000FF, 0x02000000, 0x02000001, 0x02000002,
0x020001FF, 0x02000100, 0x02000101, 0x02000102,
0x020002FF, 0x02000200, 0x02000201, 0x02000202,
0x0201FFFF, 0x0201FF00, 0x0201FF01, 0x0201FF02,
0x020100FF, 0x02010000, 0x02010001, 0x02010002,
0x020101FF, 0x02010100, 0x02010101, 0x02010102,
0x020102FF, 0x02010200, 0x02010201, 0x02010202,
0x0202FFFF, 0x0202FF00, 0x0202FF01, 0x0202FF02,
0x020200FF, 0x02020000, 0x02020001, 0x02020002,
0x020201FF, 0x02020100, 0x02020101, 0x02020102,
0x020202FF, 0x02020200, 0x02020201, 0x02020202,
// zero_point = 2: entries 512-767
0xFEFEFEFE, 0xFEFEFEFF, 0xFEFEFE00, 0xFEFEFE01,
0xFEFEFFFE, 0xFEFEFFFF, 0xFEFEFF00, 0xFEFEFF01,
0xFEFE00FE, 0xFEFE00FF, 0xFEFE0000, 0xFEFE0001,
0xFEFE01FE, 0xFEFE01FF, 0xFEFE0100, 0xFEFE0101,
0xFEFFFEFE, 0xFEFFFEFF, 0xFEFFFE00, 0xFEFFFE01,
0xFEFFFFFE, 0xFEFFFFFF, 0xFEFFFF00, 0xFEFFFF01,
0xFEFF00FE, 0xFEFF00FF, 0xFEFF0000, 0xFEFF0001,
0xFEFF01FE, 0xFEFF01FF, 0xFEFF0100, 0xFEFF0101,
0xFE00FEFE, 0xFE00FEFF, 0xFE00FE00, 0xFE00FE01,
0xFE00FFFE, 0xFE00FFFF, 0xFE00FF00, 0xFE00FF01,
0xFE0000FE, 0xFE0000FF, 0xFE000000, 0xFE000001,
0xFE0001FE, 0xFE0001FF, 0xFE000100, 0xFE000101,
0xFE01FEFE, 0xFE01FEFF, 0xFE01FE00, 0xFE01FE01,
0xFE01FFFE, 0xFE01FFFF, 0xFE01FF00, 0xFE01FF01,
0xFE0100FE, 0xFE0100FF, 0xFE010000, 0xFE010001,
0xFE0101FE, 0xFE0101FF, 0xFE010100, 0xFE010101,
0xFFFEFEFE, 0xFFFEFEFF, 0xFFFEFE00, 0xFFFEFE01,
0xFFFEFFFE, 0xFFFEFFFF, 0xFFFEFF00, 0xFFFEFF01,
0xFFFE00FE, 0xFFFE00FF, 0xFFFE0000, 0xFFFE0001,
0xFFFE01FE, 0xFFFE01FF, 0xFFFE0100, 0xFFFE0101,
0xFFFFFEFE, 0xFFFFFEFF, 0xFFFFFE00, 0xFFFFFE01,
0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFF00, 0xFFFFFF01,
0xFFFF00FE, 0xFFFF00FF, 0xFFFF0000, 0xFFFF0001,
0xFFFF01FE, 0xFFFF01FF, 0xFFFF0100, 0xFFFF0101,
0xFF00FEFE, 0xFF00FEFF, 0xFF00FE00, 0xFF00FE01,
0xFF00FFFE, 0xFF00FFFF, 0xFF00FF00, 0xFF00FF01,
0xFF0000FE, 0xFF0000FF, 0xFF000000, 0xFF000001,
0xFF0001FE, 0xFF0001FF, 0xFF000100, 0xFF000101,
0xFF01FEFE, 0xFF01FEFF, 0xFF01FE00, 0xFF01FE01,
0xFF01FFFE, 0xFF01FFFF, 0xFF01FF00, 0xFF01FF01,
0xFF0100FE, 0xFF0100FF, 0xFF010000, 0xFF010001,
0xFF0101FE, 0xFF0101FF, 0xFF010100, 0xFF010101,
0x00FEFEFE, 0x00FEFEFF, 0x00FEFE00, 0x00FEFE01,
0x00FEFFFE, 0x00FEFFFF, 0x00FEFF00, 0x00FEFF01,
0x00FE00FE, 0x00FE00FF, 0x00FE0000, 0x00FE0001,
0x00FE01FE, 0x00FE01FF, 0x00FE0100, 0x00FE0101,
0x00FFFEFE, 0x00FFFEFF, 0x00FFFE00, 0x00FFFE01,
0x00FFFFFE, 0x00FFFFFF, 0x00FFFF00, 0x00FFFF01,
0x00FF00FE, 0x00FF00FF, 0x00FF0000, 0x00FF0001,
0x00FF01FE, 0x00FF01FF, 0x00FF0100, 0x00FF0101,
0x0000FEFE, 0x0000FEFF, 0x0000FE00, 0x0000FE01,
0x0000FFFE, 0x0000FFFF, 0x0000FF00, 0x0000FF01,
0x000000FE, 0x000000FF, 0x00000000, 0x00000001,
0x000001FE, 0x000001FF, 0x00000100, 0x00000101,
0x0001FEFE, 0x0001FEFF, 0x0001FE00, 0x0001FE01,
0x0001FFFE, 0x0001FFFF, 0x0001FF00, 0x0001FF01,
0x000100FE, 0x000100FF, 0x00010000, 0x00010001,
0x000101FE, 0x000101FF, 0x00010100, 0x00010101,
0x01FEFEFE, 0x01FEFEFF, 0x01FEFE00, 0x01FEFE01,
0x01FEFFFE, 0x01FEFFFF, 0x01FEFF00, 0x01FEFF01,
0x01FE00FE, 0x01FE00FF, 0x01FE0000, 0x01FE0001,
0x01FE01FE, 0x01FE01FF, 0x01FE0100, 0x01FE0101,
0x01FFFEFE, 0x01FFFEFF, 0x01FFFE00, 0x01FFFE01,
0x01FFFFFE, 0x01FFFFFF, 0x01FFFF00, 0x01FFFF01,
0x01FF00FE, 0x01FF00FF, 0x01FF0000, 0x01FF0001,
0x01FF01FE, 0x01FF01FF, 0x01FF0100, 0x01FF0101,
0x0100FEFE, 0x0100FEFF, 0x0100FE00, 0x0100FE01,
0x0100FFFE, 0x0100FFFF, 0x0100FF00, 0x0100FF01,
0x010000FE, 0x010000FF, 0x01000000, 0x01000001,
0x010001FE, 0x010001FF, 0x01000100, 0x01000101,
0x0101FEFE, 0x0101FEFF, 0x0101FE00, 0x0101FE01,
0x0101FFFE, 0x0101FFFF, 0x0101FF00, 0x0101FF01,
0x010100FE, 0x010100FF, 0x01010000, 0x01010001,
0x010101FE, 0x010101FF, 0x01010100, 0x01010101,
// zero_point = 3: entries 768-1023
0xFDFDFDFD, 0xFDFDFDFE, 0xFDFDFDFF, 0xFDFDFD00,
0xFDFDFEFD, 0xFDFDFEFE, 0xFDFDFEFF, 0xFDFDFE00,
0xFDFDFFFD, 0xFDFDFFFE, 0xFDFDFFFF, 0xFDFDFF00,
0xFDFD00FD, 0xFDFD00FE, 0xFDFD00FF, 0xFDFD0000,
0xFDFEFDFD, 0xFDFEFDFE, 0xFDFEFDFF, 0xFDFEFD00,
0xFDFEFEFD, 0xFDFEFEFE, 0xFDFEFEFF, 0xFDFEFE00,
0xFDFEFFFD, 0xFDFEFFFE, 0xFDFEFFFF, 0xFDFEFF00,
0xFDFE00FD, 0xFDFE00FE, 0xFDFE00FF, 0xFDFE0000,
0xFDFFFDFD, 0xFDFFFDFE, 0xFDFFFDFF, 0xFDFFFD00,
0xFDFFFEFD, 0xFDFFFEFE, 0xFDFFFEFF, 0xFDFFFE00,
0xFDFFFFFD, 0xFDFFFFFE, 0xFDFFFFFF, 0xFDFFFF00,
0xFDFF00FD, 0xFDFF00FE, 0xFDFF00FF, 0xFDFF0000,
0xFD00FDFD, 0xFD00FDFE, 0xFD00FDFF, 0xFD00FD00,
0xFD00FEFD, 0xFD00FEFE, 0xFD00FEFF, 0xFD00FE00,
0xFD00FFFD, 0xFD00FFFE, 0xFD00FFFF, 0xFD00FF00,
0xFD0000FD, 0xFD0000FE, 0xFD0000FF, 0xFD000000,
0xFEFDFDFD, 0xFEFDFDFE, 0xFEFDFDFF, 0xFEFDFD00,
0xFEFDFEFD, 0xFEFDFEFE, 0xFEFDFEFF, 0xFEFDFE00,
0xFEFDFFFD, 0xFEFDFFFE, 0xFEFDFFFF, 0xFEFDFF00,
0xFEFD00FD, 0xFEFD00FE, 0xFEFD00FF, 0xFEFD0000,
0xFEFEFDFD, 0xFEFEFDFE, 0xFEFEFDFF, 0xFEFEFD00,
0xFEFEFEFD, 0xFEFEFEFE, 0xFEFEFEFF, 0xFEFEFE00,
0xFEFEFFFD, 0xFEFEFFFE, 0xFEFEFFFF, 0xFEFEFF00,
0xFEFE00FD, 0xFEFE00FE, 0xFEFE00FF, 0xFEFE0000,
0xFEFFFDFD, 0xFEFFFDFE, 0xFEFFFDFF, 0xFEFFFD00,
0xFEFFFEFD, 0xFEFFFEFE, 0xFEFFFEFF, 0xFEFFFE00,
0xFEFFFFFD, 0xFEFFFFFE, 0xFEFFFFFF, 0xFEFFFF00,
0xFEFF00FD, 0xFEFF00FE, 0xFEFF00FF, 0xFEFF0000,
0xFE00FDFD, 0xFE00FDFE, 0xFE00FDFF, 0xFE00FD00,
0xFE00FEFD, 0xFE00FEFE, 0xFE00FEFF, 0xFE00FE00,
0xFE00FFFD, 0xFE00FFFE, 0xFE00FFFF, 0xFE00FF00,
0xFE0000FD, 0xFE0000FE, 0xFE0000FF, 0xFE000000,
0xFFFDFDFD, 0xFFFDFDFE, 0xFFFDFDFF, 0xFFFDFD00,
0xFFFDFEFD, 0xFFFDFEFE, 0xFFFDFEFF, 0xFFFDFE00,
0xFFFDFFFD, 0xFFFDFFFE, 0xFFFDFFFF, 0xFFFDFF00,
0xFFFD00FD, 0xFFFD00FE, 0xFFFD00FF, 0xFFFD0000,
0xFFFEFDFD, 0xFFFEFDFE, 0xFFFEFDFF, 0xFFFEFD00,
0xFFFEFEFD, 0xFFFEFEFE, 0xFFFEFEFF, 0xFFFEFE00,
0xFFFEFFFD, 0xFFFEFFFE, 0xFFFEFFFF, 0xFFFEFF00,
0xFFFE00FD, 0xFFFE00FE, 0xFFFE00FF, 0xFFFE0000,
0xFFFFFDFD, 0xFFFFFDFE, 0xFFFFFDFF, 0xFFFFFD00,
0xFFFFFEFD, 0xFFFFFEFE, 0xFFFFFEFF, 0xFFFFFE00,
0xFFFFFFFD, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFF00,
0xFFFF00FD, 0xFFFF00FE, 0xFFFF00FF, 0xFFFF0000,
0xFF00FDFD, 0xFF00FDFE, 0xFF00FDFF, 0xFF00FD00,
0xFF00FEFD, 0xFF00FEFE, 0xFF00FEFF, 0xFF00FE00,
0xFF00FFFD, 0xFF00FFFE, 0xFF00FFFF, 0xFF00FF00,
0xFF0000FD, 0xFF0000FE, 0xFF0000FF, 0xFF000000,
0x00FDFDFD, 0x00FDFDFE, 0x00FDFDFF, 0x00FDFD00,
0x00FDFEFD, 0x00FDFEFE, 0x00FDFEFF, 0x00FDFE00,
0x00FDFFFD, 0x00FDFFFE, 0x00FDFFFF, 0x00FDFF00,
0x00FD00FD, 0x00FD00FE, 0x00FD00FF, 0x00FD0000,
0x00FEFDFD, 0x00FEFDFE, 0x00FEFDFF, 0x00FEFD00,
0x00FEFEFD, 0x00FEFEFE, 0x00FEFEFF, 0x00FEFE00,
0x00FEFFFD, 0x00FEFFFE, 0x00FEFFFF, 0x00FEFF00,
0x00FE00FD, 0x00FE00FE, 0x00FE00FF, 0x00FE0000,
0x00FFFDFD, 0x00FFFDFE, 0x00FFFDFF, 0x00FFFD00,
0x00FFFEFD, 0x00FFFEFE, 0x00FFFEFF, 0x00FFFE00,
0x00FFFFFD, 0x00FFFFFE, 0x00FFFFFF, 0x00FFFF00,
0x00FF00FD, 0x00FF00FE, 0x00FF00FF, 0x00FF0000,
0x0000FDFD, 0x0000FDFE, 0x0000FDFF, 0x0000FD00,
0x0000FEFD, 0x0000FEFE, 0x0000FEFF, 0x0000FE00,
0x0000FFFD, 0x0000FFFE, 0x0000FFFF, 0x0000FF00,
0x000000FD, 0x000000FE, 0x000000FF, 0x00000000);
fn LoadDequantizationTable(local_idx:u32)
{
// Move dequantization table into on chip memory.
shm_dequantization_table[local_idx] = q2_dequantization_table[local_idx];
}
fn DequantizedFrom2BitsTo8Bits(in: u32, zero: i32) -> vec4<u32>
{
let base = u32(zero) * 256;
let unpacked = unpack4xU8(in);
return vec4<u32>(shm_dequantization_table[base + unpacked[0]],
shm_dequantization_table[base + unpacked[1]],
shm_dequantization_table[base + unpacked[2]],
shm_dequantization_table[base + unpacked[3]]);
}
#else
const lut_size = 256;
var<workgroup> shm_dequantization_table : array<u32, lut_size>;
const q2_dequantization_table = array<u32, lut_size>(
Expand Down Expand Up @@ -313,6 +594,7 @@
shm_dequantization_table[unpacked[3]]);
}
#endif
#endif

#if has_zero_points && n_bits == 8
// If has_zero_points is true, vec4<i32>(unpack4xU8(b_data)) - vec4<i32>(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,21 @@ $MAIN {
#endif

#if n_bits == 2
#if has_zero_points
// The workgroup size is 128, LoadDequantizationTable needs to load 1024 entries.
LoadDequantizationTable(local_idx);
LoadDequantizationTable(local_idx + 128);
LoadDequantizationTable(local_idx + 256);
LoadDequantizationTable(local_idx + 384);
LoadDequantizationTable(local_idx + 512);
LoadDequantizationTable(local_idx + 640);
LoadDequantizationTable(local_idx + 768);
LoadDequantizationTable(local_idx + 896);
#else
// The workgroup size is 128, LoadDequantizationTable needs to be called twice.
LoadDequantizationTable(local_idx);
LoadDequantizationTable(local_idx+127);
LoadDequantizationTable(local_idx+128);
#endif
workgroupBarrier();
#endif
#if single_scale_weights
Expand Down Expand Up @@ -141,8 +153,13 @@ $MAIN {

#elif n_bits == 2
let b_value = b.getByOffset(b_offset);
#if has_zero_points
let own_b = DequantizedFrom2BitsTo8Bits(b_value.x, zero);
let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y, zero);
#else
let own_b = DequantizedFrom2BitsTo8Bits(b_value.x);
let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y);
#endif
inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);
#endif
}
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales,
#endif

// On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M.
// DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points.
// DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values).
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType<float>() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) &&
!(has_zero_points && nbits == 2) &&
CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) {
return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast<uint32_t>(nbits), context, y, weight_index);
}
Expand Down
Loading
Loading