diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template index 6d22e6743707b..e581d0b63cd9d 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -140,8 +140,13 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) const b_weight_offset : u32 = 0; let b_value = b.getByOffset(b_global * uniforms.K16 + kidx_v + col); #endif - tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value); let block_idx = kidx_v/(block_size/16); +#if has_zero_points + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value, zero); +#else + tile_B[col][row] = DequantizedFrom2BitsTo8Bits(b_value); +#endif let b_scale_offset = uniforms.weight_idx * uniforms.N * (uniforms.K/block_size); scale_B[row] = scales_b.getByOffset(b_scale_offset + b_global*(uniforms.K/block_size) + block_idx); } @@ -150,6 +155,11 @@ fn loadSHMA(batch:u32, a_global_base:u32, kidx_v:u32, row: u32, col: u32) $MAIN { #if n_bits == 2 LoadDequantizationTable(local_idx); +#if has_zero_points + LoadDequantizationTable(local_idx + 256); + LoadDequantizationTable(local_idx + 512); + LoadDequantizationTable(local_idx + 768); +#endif workgroupBarrier(); #endif // During the load phase we use all 256 threads to load 64 rows of A/B. 
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template index 186685b9a8dd4..3aa010cb64a2d 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template @@ -40,6 +40,287 @@ #if n_bits == 2 alias mul_precision = output_element_t; +#if has_zero_points + const lut_size = 1024; + var shm_dequantization_table : array; + // 1024-entry LUT: 4 sections of 256 entries, one per zero_point value (0-3). + // Index as zero * 256 + byte_value; entry byte k holds ((byte_value >> 2k) & 3) - zero as a two's-complement 8-bit value. + const q2_dequantization_table = array( + // zero_point = 0: entries 0-255 + 0x00000000, 0x00000001, 0x00000002, 0x00000003, + 0x00000100, 0x00000101, 0x00000102, 0x00000103, + 0x00000200, 0x00000201, 0x00000202, 0x00000203, + 0x00000300, 0x00000301, 0x00000302, 0x00000303, + 0x00010000, 0x00010001, 0x00010002, 0x00010003, + 0x00010100, 0x00010101, 0x00010102, 0x00010103, + 0x00010200, 0x00010201, 0x00010202, 0x00010203, + 0x00010300, 0x00010301, 0x00010302, 0x00010303, + 0x00020000, 0x00020001, 0x00020002, 0x00020003, + 0x00020100, 0x00020101, 0x00020102, 0x00020103, + 0x00020200, 0x00020201, 0x00020202, 0x00020203, + 0x00020300, 0x00020301, 0x00020302, 0x00020303, + 0x00030000, 0x00030001, 0x00030002, 0x00030003, + 0x00030100, 0x00030101, 0x00030102, 0x00030103, + 0x00030200, 0x00030201, 0x00030202, 0x00030203, + 0x00030300, 0x00030301, 0x00030302, 0x00030303, + 0x01000000, 0x01000001, 0x01000002, 0x01000003, + 0x01000100, 0x01000101, 0x01000102, 0x01000103, + 0x01000200, 0x01000201, 0x01000202, 0x01000203, + 0x01000300, 0x01000301, 0x01000302, 0x01000303, + 0x01010000, 0x01010001, 0x01010002, 0x01010003, + 0x01010100, 0x01010101, 0x01010102, 0x01010103, + 0x01010200, 0x01010201, 0x01010202, 0x01010203, + 0x01010300, 0x01010301, 0x01010302, 0x01010303, + 0x01020000, 0x01020001, 0x01020002, 0x01020003, + 0x01020100, 
0x01020101, 0x01020102, 0x01020103, + 0x01020200, 0x01020201, 0x01020202, 0x01020203, + 0x01020300, 0x01020301, 0x01020302, 0x01020303, + 0x01030000, 0x01030001, 0x01030002, 0x01030003, + 0x01030100, 0x01030101, 0x01030102, 0x01030103, + 0x01030200, 0x01030201, 0x01030202, 0x01030203, + 0x01030300, 0x01030301, 0x01030302, 0x01030303, + 0x02000000, 0x02000001, 0x02000002, 0x02000003, + 0x02000100, 0x02000101, 0x02000102, 0x02000103, + 0x02000200, 0x02000201, 0x02000202, 0x02000203, + 0x02000300, 0x02000301, 0x02000302, 0x02000303, + 0x02010000, 0x02010001, 0x02010002, 0x02010003, + 0x02010100, 0x02010101, 0x02010102, 0x02010103, + 0x02010200, 0x02010201, 0x02010202, 0x02010203, + 0x02010300, 0x02010301, 0x02010302, 0x02010303, + 0x02020000, 0x02020001, 0x02020002, 0x02020003, + 0x02020100, 0x02020101, 0x02020102, 0x02020103, + 0x02020200, 0x02020201, 0x02020202, 0x02020203, + 0x02020300, 0x02020301, 0x02020302, 0x02020303, + 0x02030000, 0x02030001, 0x02030002, 0x02030003, + 0x02030100, 0x02030101, 0x02030102, 0x02030103, + 0x02030200, 0x02030201, 0x02030202, 0x02030203, + 0x02030300, 0x02030301, 0x02030302, 0x02030303, + 0x03000000, 0x03000001, 0x03000002, 0x03000003, + 0x03000100, 0x03000101, 0x03000102, 0x03000103, + 0x03000200, 0x03000201, 0x03000202, 0x03000203, + 0x03000300, 0x03000301, 0x03000302, 0x03000303, + 0x03010000, 0x03010001, 0x03010002, 0x03010003, + 0x03010100, 0x03010101, 0x03010102, 0x03010103, + 0x03010200, 0x03010201, 0x03010202, 0x03010203, + 0x03010300, 0x03010301, 0x03010302, 0x03010303, + 0x03020000, 0x03020001, 0x03020002, 0x03020003, + 0x03020100, 0x03020101, 0x03020102, 0x03020103, + 0x03020200, 0x03020201, 0x03020202, 0x03020203, + 0x03020300, 0x03020301, 0x03020302, 0x03020303, + 0x03030000, 0x03030001, 0x03030002, 0x03030003, + 0x03030100, 0x03030101, 0x03030102, 0x03030103, + 0x03030200, 0x03030201, 0x03030202, 0x03030203, + 0x03030300, 0x03030301, 0x03030302, 0x03030303, + // zero_point = 1: entries 256-511 + 0xFFFFFFFF, 0xFFFFFF00, 
0xFFFFFF01, 0xFFFFFF02, + 0xFFFF00FF, 0xFFFF0000, 0xFFFF0001, 0xFFFF0002, + 0xFFFF01FF, 0xFFFF0100, 0xFFFF0101, 0xFFFF0102, + 0xFFFF02FF, 0xFFFF0200, 0xFFFF0201, 0xFFFF0202, + 0xFF00FFFF, 0xFF00FF00, 0xFF00FF01, 0xFF00FF02, + 0xFF0000FF, 0xFF000000, 0xFF000001, 0xFF000002, + 0xFF0001FF, 0xFF000100, 0xFF000101, 0xFF000102, + 0xFF0002FF, 0xFF000200, 0xFF000201, 0xFF000202, + 0xFF01FFFF, 0xFF01FF00, 0xFF01FF01, 0xFF01FF02, + 0xFF0100FF, 0xFF010000, 0xFF010001, 0xFF010002, + 0xFF0101FF, 0xFF010100, 0xFF010101, 0xFF010102, + 0xFF0102FF, 0xFF010200, 0xFF010201, 0xFF010202, + 0xFF02FFFF, 0xFF02FF00, 0xFF02FF01, 0xFF02FF02, + 0xFF0200FF, 0xFF020000, 0xFF020001, 0xFF020002, + 0xFF0201FF, 0xFF020100, 0xFF020101, 0xFF020102, + 0xFF0202FF, 0xFF020200, 0xFF020201, 0xFF020202, + 0x00FFFFFF, 0x00FFFF00, 0x00FFFF01, 0x00FFFF02, + 0x00FF00FF, 0x00FF0000, 0x00FF0001, 0x00FF0002, + 0x00FF01FF, 0x00FF0100, 0x00FF0101, 0x00FF0102, + 0x00FF02FF, 0x00FF0200, 0x00FF0201, 0x00FF0202, + 0x0000FFFF, 0x0000FF00, 0x0000FF01, 0x0000FF02, + 0x000000FF, 0x00000000, 0x00000001, 0x00000002, + 0x000001FF, 0x00000100, 0x00000101, 0x00000102, + 0x000002FF, 0x00000200, 0x00000201, 0x00000202, + 0x0001FFFF, 0x0001FF00, 0x0001FF01, 0x0001FF02, + 0x000100FF, 0x00010000, 0x00010001, 0x00010002, + 0x000101FF, 0x00010100, 0x00010101, 0x00010102, + 0x000102FF, 0x00010200, 0x00010201, 0x00010202, + 0x0002FFFF, 0x0002FF00, 0x0002FF01, 0x0002FF02, + 0x000200FF, 0x00020000, 0x00020001, 0x00020002, + 0x000201FF, 0x00020100, 0x00020101, 0x00020102, + 0x000202FF, 0x00020200, 0x00020201, 0x00020202, + 0x01FFFFFF, 0x01FFFF00, 0x01FFFF01, 0x01FFFF02, + 0x01FF00FF, 0x01FF0000, 0x01FF0001, 0x01FF0002, + 0x01FF01FF, 0x01FF0100, 0x01FF0101, 0x01FF0102, + 0x01FF02FF, 0x01FF0200, 0x01FF0201, 0x01FF0202, + 0x0100FFFF, 0x0100FF00, 0x0100FF01, 0x0100FF02, + 0x010000FF, 0x01000000, 0x01000001, 0x01000002, + 0x010001FF, 0x01000100, 0x01000101, 0x01000102, + 0x010002FF, 0x01000200, 0x01000201, 0x01000202, + 0x0101FFFF, 0x0101FF00, 
0x0101FF01, 0x0101FF02, + 0x010100FF, 0x01010000, 0x01010001, 0x01010002, + 0x010101FF, 0x01010100, 0x01010101, 0x01010102, + 0x010102FF, 0x01010200, 0x01010201, 0x01010202, + 0x0102FFFF, 0x0102FF00, 0x0102FF01, 0x0102FF02, + 0x010200FF, 0x01020000, 0x01020001, 0x01020002, + 0x010201FF, 0x01020100, 0x01020101, 0x01020102, + 0x010202FF, 0x01020200, 0x01020201, 0x01020202, + 0x02FFFFFF, 0x02FFFF00, 0x02FFFF01, 0x02FFFF02, + 0x02FF00FF, 0x02FF0000, 0x02FF0001, 0x02FF0002, + 0x02FF01FF, 0x02FF0100, 0x02FF0101, 0x02FF0102, + 0x02FF02FF, 0x02FF0200, 0x02FF0201, 0x02FF0202, + 0x0200FFFF, 0x0200FF00, 0x0200FF01, 0x0200FF02, + 0x020000FF, 0x02000000, 0x02000001, 0x02000002, + 0x020001FF, 0x02000100, 0x02000101, 0x02000102, + 0x020002FF, 0x02000200, 0x02000201, 0x02000202, + 0x0201FFFF, 0x0201FF00, 0x0201FF01, 0x0201FF02, + 0x020100FF, 0x02010000, 0x02010001, 0x02010002, + 0x020101FF, 0x02010100, 0x02010101, 0x02010102, + 0x020102FF, 0x02010200, 0x02010201, 0x02010202, + 0x0202FFFF, 0x0202FF00, 0x0202FF01, 0x0202FF02, + 0x020200FF, 0x02020000, 0x02020001, 0x02020002, + 0x020201FF, 0x02020100, 0x02020101, 0x02020102, + 0x020202FF, 0x02020200, 0x02020201, 0x02020202, + // zero_point = 2: entries 512-767 + 0xFEFEFEFE, 0xFEFEFEFF, 0xFEFEFE00, 0xFEFEFE01, + 0xFEFEFFFE, 0xFEFEFFFF, 0xFEFEFF00, 0xFEFEFF01, + 0xFEFE00FE, 0xFEFE00FF, 0xFEFE0000, 0xFEFE0001, + 0xFEFE01FE, 0xFEFE01FF, 0xFEFE0100, 0xFEFE0101, + 0xFEFFFEFE, 0xFEFFFEFF, 0xFEFFFE00, 0xFEFFFE01, + 0xFEFFFFFE, 0xFEFFFFFF, 0xFEFFFF00, 0xFEFFFF01, + 0xFEFF00FE, 0xFEFF00FF, 0xFEFF0000, 0xFEFF0001, + 0xFEFF01FE, 0xFEFF01FF, 0xFEFF0100, 0xFEFF0101, + 0xFE00FEFE, 0xFE00FEFF, 0xFE00FE00, 0xFE00FE01, + 0xFE00FFFE, 0xFE00FFFF, 0xFE00FF00, 0xFE00FF01, + 0xFE0000FE, 0xFE0000FF, 0xFE000000, 0xFE000001, + 0xFE0001FE, 0xFE0001FF, 0xFE000100, 0xFE000101, + 0xFE01FEFE, 0xFE01FEFF, 0xFE01FE00, 0xFE01FE01, + 0xFE01FFFE, 0xFE01FFFF, 0xFE01FF00, 0xFE01FF01, + 0xFE0100FE, 0xFE0100FF, 0xFE010000, 0xFE010001, + 0xFE0101FE, 0xFE0101FF, 0xFE010100, 
0xFE010101, + 0xFFFEFEFE, 0xFFFEFEFF, 0xFFFEFE00, 0xFFFEFE01, + 0xFFFEFFFE, 0xFFFEFFFF, 0xFFFEFF00, 0xFFFEFF01, + 0xFFFE00FE, 0xFFFE00FF, 0xFFFE0000, 0xFFFE0001, + 0xFFFE01FE, 0xFFFE01FF, 0xFFFE0100, 0xFFFE0101, + 0xFFFFFEFE, 0xFFFFFEFF, 0xFFFFFE00, 0xFFFFFE01, + 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFF00, 0xFFFFFF01, + 0xFFFF00FE, 0xFFFF00FF, 0xFFFF0000, 0xFFFF0001, + 0xFFFF01FE, 0xFFFF01FF, 0xFFFF0100, 0xFFFF0101, + 0xFF00FEFE, 0xFF00FEFF, 0xFF00FE00, 0xFF00FE01, + 0xFF00FFFE, 0xFF00FFFF, 0xFF00FF00, 0xFF00FF01, + 0xFF0000FE, 0xFF0000FF, 0xFF000000, 0xFF000001, + 0xFF0001FE, 0xFF0001FF, 0xFF000100, 0xFF000101, + 0xFF01FEFE, 0xFF01FEFF, 0xFF01FE00, 0xFF01FE01, + 0xFF01FFFE, 0xFF01FFFF, 0xFF01FF00, 0xFF01FF01, + 0xFF0100FE, 0xFF0100FF, 0xFF010000, 0xFF010001, + 0xFF0101FE, 0xFF0101FF, 0xFF010100, 0xFF010101, + 0x00FEFEFE, 0x00FEFEFF, 0x00FEFE00, 0x00FEFE01, + 0x00FEFFFE, 0x00FEFFFF, 0x00FEFF00, 0x00FEFF01, + 0x00FE00FE, 0x00FE00FF, 0x00FE0000, 0x00FE0001, + 0x00FE01FE, 0x00FE01FF, 0x00FE0100, 0x00FE0101, + 0x00FFFEFE, 0x00FFFEFF, 0x00FFFE00, 0x00FFFE01, + 0x00FFFFFE, 0x00FFFFFF, 0x00FFFF00, 0x00FFFF01, + 0x00FF00FE, 0x00FF00FF, 0x00FF0000, 0x00FF0001, + 0x00FF01FE, 0x00FF01FF, 0x00FF0100, 0x00FF0101, + 0x0000FEFE, 0x0000FEFF, 0x0000FE00, 0x0000FE01, + 0x0000FFFE, 0x0000FFFF, 0x0000FF00, 0x0000FF01, + 0x000000FE, 0x000000FF, 0x00000000, 0x00000001, + 0x000001FE, 0x000001FF, 0x00000100, 0x00000101, + 0x0001FEFE, 0x0001FEFF, 0x0001FE00, 0x0001FE01, + 0x0001FFFE, 0x0001FFFF, 0x0001FF00, 0x0001FF01, + 0x000100FE, 0x000100FF, 0x00010000, 0x00010001, + 0x000101FE, 0x000101FF, 0x00010100, 0x00010101, + 0x01FEFEFE, 0x01FEFEFF, 0x01FEFE00, 0x01FEFE01, + 0x01FEFFFE, 0x01FEFFFF, 0x01FEFF00, 0x01FEFF01, + 0x01FE00FE, 0x01FE00FF, 0x01FE0000, 0x01FE0001, + 0x01FE01FE, 0x01FE01FF, 0x01FE0100, 0x01FE0101, + 0x01FFFEFE, 0x01FFFEFF, 0x01FFFE00, 0x01FFFE01, + 0x01FFFFFE, 0x01FFFFFF, 0x01FFFF00, 0x01FFFF01, + 0x01FF00FE, 0x01FF00FF, 0x01FF0000, 0x01FF0001, + 0x01FF01FE, 0x01FF01FF, 0x01FF0100, 
0x01FF0101, + 0x0100FEFE, 0x0100FEFF, 0x0100FE00, 0x0100FE01, + 0x0100FFFE, 0x0100FFFF, 0x0100FF00, 0x0100FF01, + 0x010000FE, 0x010000FF, 0x01000000, 0x01000001, + 0x010001FE, 0x010001FF, 0x01000100, 0x01000101, + 0x0101FEFE, 0x0101FEFF, 0x0101FE00, 0x0101FE01, + 0x0101FFFE, 0x0101FFFF, 0x0101FF00, 0x0101FF01, + 0x010100FE, 0x010100FF, 0x01010000, 0x01010001, + 0x010101FE, 0x010101FF, 0x01010100, 0x01010101, + // zero_point = 3: entries 768-1023 + 0xFDFDFDFD, 0xFDFDFDFE, 0xFDFDFDFF, 0xFDFDFD00, + 0xFDFDFEFD, 0xFDFDFEFE, 0xFDFDFEFF, 0xFDFDFE00, + 0xFDFDFFFD, 0xFDFDFFFE, 0xFDFDFFFF, 0xFDFDFF00, + 0xFDFD00FD, 0xFDFD00FE, 0xFDFD00FF, 0xFDFD0000, + 0xFDFEFDFD, 0xFDFEFDFE, 0xFDFEFDFF, 0xFDFEFD00, + 0xFDFEFEFD, 0xFDFEFEFE, 0xFDFEFEFF, 0xFDFEFE00, + 0xFDFEFFFD, 0xFDFEFFFE, 0xFDFEFFFF, 0xFDFEFF00, + 0xFDFE00FD, 0xFDFE00FE, 0xFDFE00FF, 0xFDFE0000, + 0xFDFFFDFD, 0xFDFFFDFE, 0xFDFFFDFF, 0xFDFFFD00, + 0xFDFFFEFD, 0xFDFFFEFE, 0xFDFFFEFF, 0xFDFFFE00, + 0xFDFFFFFD, 0xFDFFFFFE, 0xFDFFFFFF, 0xFDFFFF00, + 0xFDFF00FD, 0xFDFF00FE, 0xFDFF00FF, 0xFDFF0000, + 0xFD00FDFD, 0xFD00FDFE, 0xFD00FDFF, 0xFD00FD00, + 0xFD00FEFD, 0xFD00FEFE, 0xFD00FEFF, 0xFD00FE00, + 0xFD00FFFD, 0xFD00FFFE, 0xFD00FFFF, 0xFD00FF00, + 0xFD0000FD, 0xFD0000FE, 0xFD0000FF, 0xFD000000, + 0xFEFDFDFD, 0xFEFDFDFE, 0xFEFDFDFF, 0xFEFDFD00, + 0xFEFDFEFD, 0xFEFDFEFE, 0xFEFDFEFF, 0xFEFDFE00, + 0xFEFDFFFD, 0xFEFDFFFE, 0xFEFDFFFF, 0xFEFDFF00, + 0xFEFD00FD, 0xFEFD00FE, 0xFEFD00FF, 0xFEFD0000, + 0xFEFEFDFD, 0xFEFEFDFE, 0xFEFEFDFF, 0xFEFEFD00, + 0xFEFEFEFD, 0xFEFEFEFE, 0xFEFEFEFF, 0xFEFEFE00, + 0xFEFEFFFD, 0xFEFEFFFE, 0xFEFEFFFF, 0xFEFEFF00, + 0xFEFE00FD, 0xFEFE00FE, 0xFEFE00FF, 0xFEFE0000, + 0xFEFFFDFD, 0xFEFFFDFE, 0xFEFFFDFF, 0xFEFFFD00, + 0xFEFFFEFD, 0xFEFFFEFE, 0xFEFFFEFF, 0xFEFFFE00, + 0xFEFFFFFD, 0xFEFFFFFE, 0xFEFFFFFF, 0xFEFFFF00, + 0xFEFF00FD, 0xFEFF00FE, 0xFEFF00FF, 0xFEFF0000, + 0xFE00FDFD, 0xFE00FDFE, 0xFE00FDFF, 0xFE00FD00, + 0xFE00FEFD, 0xFE00FEFE, 0xFE00FEFF, 0xFE00FE00, + 0xFE00FFFD, 0xFE00FFFE, 0xFE00FFFF, 0xFE00FF00, 
+ 0xFE0000FD, 0xFE0000FE, 0xFE0000FF, 0xFE000000, + 0xFFFDFDFD, 0xFFFDFDFE, 0xFFFDFDFF, 0xFFFDFD00, + 0xFFFDFEFD, 0xFFFDFEFE, 0xFFFDFEFF, 0xFFFDFE00, + 0xFFFDFFFD, 0xFFFDFFFE, 0xFFFDFFFF, 0xFFFDFF00, + 0xFFFD00FD, 0xFFFD00FE, 0xFFFD00FF, 0xFFFD0000, + 0xFFFEFDFD, 0xFFFEFDFE, 0xFFFEFDFF, 0xFFFEFD00, + 0xFFFEFEFD, 0xFFFEFEFE, 0xFFFEFEFF, 0xFFFEFE00, + 0xFFFEFFFD, 0xFFFEFFFE, 0xFFFEFFFF, 0xFFFEFF00, + 0xFFFE00FD, 0xFFFE00FE, 0xFFFE00FF, 0xFFFE0000, + 0xFFFFFDFD, 0xFFFFFDFE, 0xFFFFFDFF, 0xFFFFFD00, + 0xFFFFFEFD, 0xFFFFFEFE, 0xFFFFFEFF, 0xFFFFFE00, + 0xFFFFFFFD, 0xFFFFFFFE, 0xFFFFFFFF, 0xFFFFFF00, + 0xFFFF00FD, 0xFFFF00FE, 0xFFFF00FF, 0xFFFF0000, + 0xFF00FDFD, 0xFF00FDFE, 0xFF00FDFF, 0xFF00FD00, + 0xFF00FEFD, 0xFF00FEFE, 0xFF00FEFF, 0xFF00FE00, + 0xFF00FFFD, 0xFF00FFFE, 0xFF00FFFF, 0xFF00FF00, + 0xFF0000FD, 0xFF0000FE, 0xFF0000FF, 0xFF000000, + 0x00FDFDFD, 0x00FDFDFE, 0x00FDFDFF, 0x00FDFD00, + 0x00FDFEFD, 0x00FDFEFE, 0x00FDFEFF, 0x00FDFE00, + 0x00FDFFFD, 0x00FDFFFE, 0x00FDFFFF, 0x00FDFF00, + 0x00FD00FD, 0x00FD00FE, 0x00FD00FF, 0x00FD0000, + 0x00FEFDFD, 0x00FEFDFE, 0x00FEFDFF, 0x00FEFD00, + 0x00FEFEFD, 0x00FEFEFE, 0x00FEFEFF, 0x00FEFE00, + 0x00FEFFFD, 0x00FEFFFE, 0x00FEFFFF, 0x00FEFF00, + 0x00FE00FD, 0x00FE00FE, 0x00FE00FF, 0x00FE0000, + 0x00FFFDFD, 0x00FFFDFE, 0x00FFFDFF, 0x00FFFD00, + 0x00FFFEFD, 0x00FFFEFE, 0x00FFFEFF, 0x00FFFE00, + 0x00FFFFFD, 0x00FFFFFE, 0x00FFFFFF, 0x00FFFF00, + 0x00FF00FD, 0x00FF00FE, 0x00FF00FF, 0x00FF0000, + 0x0000FDFD, 0x0000FDFE, 0x0000FDFF, 0x0000FD00, + 0x0000FEFD, 0x0000FEFE, 0x0000FEFF, 0x0000FE00, + 0x0000FFFD, 0x0000FFFE, 0x0000FFFF, 0x0000FF00, + 0x000000FD, 0x000000FE, 0x000000FF, 0x00000000); + fn LoadDequantizationTable(local_idx:u32) + { + // Move dequantization table into on chip memory. 
+ shm_dequantization_table[local_idx] = q2_dequantization_table[local_idx]; + } + fn DequantizedFrom2BitsTo8Bits(in: u32, zero: i32) -> vec4 + { + let base = u32(zero) * 256; + let unpacked = unpack4xU8(in); + return vec4(shm_dequantization_table[base + unpacked[0]], + shm_dequantization_table[base + unpacked[1]], + shm_dequantization_table[base + unpacked[2]], + shm_dequantization_table[base + unpacked[3]]); + } +#else const lut_size = 256; var shm_dequantization_table : array; const q2_dequantization_table = array( @@ -313,6 +594,7 @@ shm_dequantization_table[unpacked[3]]); } #endif +#endif #if has_zero_points && n_bits == 8 // If has_zero_points is true, vec4(unpack4xU8(b_data)) - vec4(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255]. diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template index 0cf1fd6e39d70..e7f6aa450b407 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template @@ -80,9 +80,21 @@ $MAIN { #endif #if n_bits == 2 +#if has_zero_points + // The workgroup size is 128 and the LUT has 1024 entries, so LoadDequantizationTable is called 8 times (stride 128). + LoadDequantizationTable(local_idx); + LoadDequantizationTable(local_idx + 128); + LoadDequantizationTable(local_idx + 256); + LoadDequantizationTable(local_idx + 384); + LoadDequantizationTable(local_idx + 512); + LoadDequantizationTable(local_idx + 640); + LoadDequantizationTable(local_idx + 768); + LoadDequantizationTable(local_idx + 896); +#else // The workgroup size is 128, LoadDequantizationTable needs to be called twice. 
LoadDequantizationTable(local_idx); - LoadDequantizationTable(local_idx+127); + LoadDequantizationTable(local_idx+128); +#endif workgroupBarrier(); #endif #if single_scale_weights @@ -141,8 +153,13 @@ $MAIN { #elif n_bits == 2 let b_value = b.getByOffset(b_offset); +#if has_zero_points + let own_b = DequantizedFrom2BitsTo8Bits(b_value.x, zero); + let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y, zero); +#else let own_b = DequantizedFrom2BitsTo8Bits(b_value.x); let own_b1 = DequantizedFrom2BitsTo8Bits(b_value.y); +#endif inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); #endif } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index a2b89dfcee4d6..cb5654509fe8a 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -221,9 +221,8 @@ Status ApplyMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, #endif // On FP32 only GPUs, integer math is faster than FP32 therefore always use DP4A independent of length of M. - // DP4A Q2 path uses a hardcoded LUT with zero_point=2, so skip DP4A for Q2 with custom zero points. + // DP4A Q2 path now supports custom zero points via a 1024-entry LUT (4 zero-point sections × 256 byte values). 
if ((M >= kMinMForTileOptimization || y->DataType() == DataTypeImpl::GetType() || context.AdapterInfo().vendor == std::string_view{"qualcomm"}) && - !(has_zero_points && nbits == 2) && CanApplyDP4AMatrixMatMulNBits(context, accuracy_level, block_size, N, K, components_a)) { return ApplyDP4AMatrixMatMulNBits(a, b, scales, zero_points, bias, batch_count, M, N, K, block_size, zero_blocks_per_col, kMinMForTileOptimization, static_cast(nbits), context, y, weight_index); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc index 58b0f3ada9341..b9eafeb43c7b6 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc @@ -4,6 +4,7 @@ #include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" #include #include "core/common/common.h" +#include "core/providers/webgpu/webgpu_context.h" namespace onnxruntime { namespace contrib { @@ -54,6 +55,12 @@ fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> )" return ss.str(); } +bool HasDP4ADeviceSupport(int context_id) { + auto& ctx = onnxruntime::webgpu::WebGpuContextFactory::GetContext(context_id); + return ctx.DeviceHasFeature(wgpu::FeatureName::Subgroups) && + ctx.AdapterInfo().vendor != std::string_view{"apple"}; +} + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h index dde18e8a78fd2..3db7c722b11eb 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h @@ -21,6 +21,11 @@ namespace webgpu { std::string GenerateZeroPointReadingCode(uint32_t nbits, bool has_zero_points, const std::string& output_type = "output_element_t"); +/// Returns true when the selected 
WebGPU device supports the DP4A kernel path +/// (Subgroups feature present and non-Apple vendor). +/// \p context_id is the WebGpuContext slot (0 for the default context). +bool HasDP4ADeviceSupport(int context_id = 0); + } // namespace webgpu } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index c2685c69db877..00ea8947b9dd7 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -27,6 +27,9 @@ #include "core/session/ort_env.h" #include "core/util/qmath.h" #include "core/providers/webgpu/webgpu_provider_options.h" +#ifdef USE_WEBGPU +#include "contrib_ops/webgpu/quantization/matmul_nbits_common.h" +#endif extern std::unique_ptr ort_env; @@ -540,6 +543,57 @@ TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_LargerK) { RunWebGpu2BitsTest(1, 32, 256, 32, true, 0.3f, 0.05f); } +// DP4A path test (accuracy_level=4) — exercises the 1024-entry LUT dequantization +// path for 2-bit weights with zero_points. +// DP4A constraints: accuracy_level==4, block_size%32==0, K%128==0, N%16==0. +// Skipped when the adapter lacks Subgroups support or is Apple (Metal), +// because MatMulNBits would then silently fall back to the non-DP4A path. +TEST(MatMul2BitsWebGpu, Float32_ZeroPoint_DP4A) { + // Ensure the WebGPU context is initialized so we can query adapter capabilities. 
+ auto ep = DefaultWebGpuExecutionProvider(); + if (!contrib::webgpu::HasDP4ADeviceSupport(ep->GetDeviceId())) { + GTEST_SKIP() << "DP4A requires Subgroups support on a non-Apple adapter"; + } + + TestOptions2Bits opts{}; + opts.accuracy_level = 4; + opts.has_zero_point = true; + opts.output_abs_error = 0.1f; + opts.output_rel_error = 0.02f; + + // M=1, N=16, K=128, block_size=32 — minimal DP4A-eligible shape + opts.M = 1; + opts.N = 16; + opts.K = 128; + opts.block_size = 32; + RunTest2Bits(opts); + + // M=1, N=32, K=256, block_size=32 — larger K + opts.M = 1; + opts.N = 32; + opts.K = 256; + opts.block_size = 32; + opts.output_abs_error = 0.3f; + opts.output_rel_error = 0.05f; + RunTest2Bits(opts); + + // M=4 (rows), N=32, K=128, block_size=32 + opts.M = 4; + opts.N = 32; + opts.K = 128; + opts.block_size = 32; + opts.output_abs_error = 0.1f; + opts.output_rel_error = 0.02f; + RunTest2Bits(opts); + + // M=1, N=16, K=128, block_size=128 — a single quantization block spans all of K + opts.M = 1; + opts.N = 16; + opts.K = 128; + opts.block_size = 128; + RunTest2Bits(opts); +} + #endif // USE_WEBGPU } // namespace test