Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNBits #27325
HectorSVC wants to merge 4 commits into microsoft:main
Conversation
Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNBits kernel. Previously, the dp4a path for 2-bit quantization used a hardcoded 256-entry LUT assuming zero_point=2, and was blocked from running when custom zero points were provided.
Reflect the review comments in PR #27285.

/azp run Linux QNN CI Pipeline, Win_TRT_Minimal_CUDA_Test_CI, Windows ARM64 QNN CI Pipeline, Windows GPU Doc Gen CI Pipeline

Azure Pipelines successfully started running 4 pipeline(s).
Pull request overview
This PR enables the WebGPU DP4A implementation of MatMulNBits for 2-bit quantized weights when custom zero_points are provided, by extending the dequantization LUT logic in the DP4A WGSL shaders and removing the previous dispatch guard that blocked this path.
Changes:
- Extend the Q2 DP4A dequantization LUT from 256 entries to 1024 entries when `has_zero_points` is true, and route dequantization through the zero-point-aware lookup.
- Update both DP4A kernels (large-M and small-M variants) to load the expanded LUT and pass zero points into Q2 dequantization.
- Remove the C++ dispatch guard that prevented selecting the DP4A path for (nbits=2, has_zero_points=true), and update the in-code comment accordingly.
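The LUT scheme described above can be sketched as a plain Python simulation of the WGSL logic (names mirror the shader; the lane order, with the low 2 bits of each byte as lane 0, is an assumption for illustration, not taken from the actual template code):

```python
def pack4xI8(lanes):
    """Simulate WGSL pack4xI8: pack four signed 8-bit lanes into one u32."""
    out = 0
    for i, v in enumerate(lanes):
        out |= (v & 0xFF) << (8 * i)
    return out

def build_q2_lut():
    """1024-entry LUT: 4 sections of 256 entries, one per zero point 0..3.
    Entry [zp * 256 + byte] holds the four 2-bit weights of `byte`,
    each shifted by -zp, packed as i8 lanes."""
    lut = []
    for zp in range(4):
        for byte in range(256):
            lanes = [((byte >> (2 * i)) & 0x3) - zp for i in range(4)]
            lut.append(pack4xI8(lanes))
    return lut

def dequant_2bit(lut, byte, zp):
    """Zero-point-aware lookup, mirroring the new
    DequantizedFrom2BitsTo8Bits(in, zero) overload."""
    return lut[zp * 256 + byte]
```

For example, the byte `0b11100100` encodes the four 2-bit weights {0, 1, 2, 3}; with zero point 2, the looked-up value packs the lanes {-2, -1, 0, 1}, which is exactly what `pack4xI8(value - zero_point)` pre-computes per section.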
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | Allows DP4A dispatch for 2-bit with zero points by removing the previous guard and updating the comment. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template | Adds a 1024-entry LUT (4×256) and a zero-point-aware Q2 dequantization function. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template | Large-M DP4A kernel loads all 1024 LUT entries when needed and passes the per-block zero point into Q2 dequantization. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template | Small-M DP4A kernel loads all 1024 LUT entries when needed, uses zero-point-aware Q2 dequantization, and fixes a LUT-load offset bug in the non-zero-point path. |
Pull request overview
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated no new comments.
Add zero point support to dp4a 2-bit dequantization in the WebGPU MatMulNBits kernel. Previously, the dp4a path for 2-bit quantization used a hardcoded 256-entry LUT assuming zero_point=2, and was blocked from running when custom zero points were provided.
dp4a_matmul_common.wgsl.template — Core LUT & dequantization function
- Added a 1024-entry LUT (4 sections × 256 entries) when `has_zero_points` is true. Each section corresponds to a zero point value (0–3), pre-computing `pack4xI8(value - zero_point)` for every possible byte input.
- Added a new `DequantizedFrom2BitsTo8Bits(in: u32, zero: i32)` overload that indexes the LUT as `zero * 256 + byte_value`.
- The original 256-entry LUT and the parameterless function are preserved for the `!has_zero_points` path.
dp4a_matmul.wgsl.template — Large-M tiled kernel (workgroup=256)
- `loadSHMB` for `n_bits == 2`: reads the zero point via `mm_read_zero()` and passes it to `DequantizedFrom2BitsTo8Bits(b_value, zero)` when `has_zero_points`.
- `LoadDequantizationTable`: expanded to 4 calls (`local_idx + 0/256/512/768`) to load all 1024 entries when `has_zero_points`.
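The strided cooperative load can be sketched in Python (an illustrative simulation of the workgroup behavior, not the template code; scales and shader variable names are omitted):

```python
WORKGROUP_SIZE = 256  # large-M kernel

def load_dequantization_table(lut_src, has_zero_points):
    """Each thread (local_idx) stores one entry per 256-entry stride,
    so four strided stores (+0, +256, +512, +768) cover all 1024
    entries; without zero points a single store covers the 256."""
    table = [None] * (1024 if has_zero_points else 256)
    for local_idx in range(WORKGROUP_SIZE):  # all threads in the workgroup
        if has_zero_points:
            for stride in (0, 256, 512, 768):
                table[local_idx + stride] = lut_src[local_idx + stride]
        else:
            table[local_idx] = lut_src[local_idx]
    return table
```

In the real shader the per-thread stores run concurrently and are followed by a workgroup barrier; the loop over `local_idx` here just stands in for the 256 parallel invocations.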
dp4a_matmul_small_m.wgsl.template — Small-M kernel (workgroup=128)
- `LoadDequantizationTable`: expanded to 8 calls to load 1024 entries when `has_zero_points`.
- `DequantizedFrom2BitsTo8Bits` calls pass `zero` when `has_zero_points`.
- Bug fix: corrected the off-by-one `local_idx + 127` → `local_idx + 128` in the non-zero-point path.
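The small-M variant has only 128 threads, so covering the table takes 8 strides of 128 (or 2 for the 256-entry case). A Python sketch of the corrected load (illustrative simulation, not the template code) also shows why the old `+127` was wrong: thread 0 would write entries 0 and 127 while thread 127 wrote 127 and 254, so entry 127 was written twice and entry 255 never:

```python
WORKGROUP_SIZE = 128  # small-M kernel

def load_table_small_m(lut_src, has_zero_points):
    size = 1024 if has_zero_points else 256
    table = [None] * size
    for local_idx in range(WORKGROUP_SIZE):  # 128 parallel invocations
        if has_zero_points:
            # 8 strided stores of 128 cover all 1024 entries
            for stride in range(0, 1024, 128):
                table[local_idx + stride] = lut_src[local_idx + stride]
        else:
            table[local_idx] = lut_src[local_idx]
            # fixed offset: +128 (the old +127 skipped entry 255
            # and double-wrote entry 127)
            table[local_idx + 128] = lut_src[local_idx + 128]
    return table
```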
matmul_nbits.cc — Kernel dispatch logic
- Removed the guard `!(has_zero_points && nbits == 2)` that previously blocked the dp4a path for 2-bit weights with custom zero points.
- Updated the comment to document the new 1024-entry LUT support.
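The dispatch change amounts to deleting one conjunct from the DP4A eligibility predicate. A before/after sketch (Python for illustration; `other_conditions` is a hypothetical stand-in for the remaining checks in matmul_nbits.cc, which are not spelled out here):

```python
def dp4a_eligible_before(nbits, has_zero_points, other_conditions=True):
    # Old behavior: DP4A was blocked for 2-bit weights with custom zero points.
    return other_conditions and not (has_zero_points and nbits == 2)

def dp4a_eligible_after(nbits, has_zero_points, other_conditions=True):
    # New behavior: the clause is simply removed, because the 1024-entry LUT
    # now handles every (zero_point, byte) combination for 2-bit weights.
    return other_conditions
```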