Commit d172021

Authored and committed by morelos
[ET-VK][Ops] torchao.dequantize_affine vulkan impl and shader and cleanup
# Changes

* Implement the `torchao.dequantize_affine` operator in the Vulkan backend with comprehensive texture and buffer storage support
* Add a block-wise dequantization mode to the `dequantize_texture.glsl` and `dequantize_buffer.glsl` shaders for configurable tensor-block dequantization
* Extend the dequantization infrastructure in `Dequantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters
* Support integer-to-floating-point conversion with precise reconstruction of the original values

BE: Improved the documentation in the shader logic so it is more detailed and clear.

# Motivation

The existing Vulkan quantization infrastructure lacked support for the `torchao.dequantize_affine` operator, which is essential for completing the quantization-dequantization cycle in dynamic quantization workflows. The `dequantize_affine` operator provides flexible block-wise dequantization that reconstructs floating-point values from quantized integer blocks, enabling:

* **Block-wise Dequantization**: Reconstructs floating-point values from configurable tensor blocks using separate scale and zero-point parameters, enabling precise recovery of the original data distributions
* **Affine Transformation**: Uses the formula `value = (qvalue - zero_point) * scale` for accurate integer-to-floating-point mapping
* **TorchAO Integration**: Integrates seamlessly with TorchAO quantization workflows and completes the quantization-dequantization round trip

# Operator Description

The `dequantize_affine` operator converts n-bit integer tensor values back to floating-point representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise dequantization divides tensors into blocks and applies separate dequantization parameters to each block, allowing fine-grained reconstruction of the original floating-point precision.

The dequantization formula is: `value = (qvalue - zero_point) * scale`

**Storage Requirements**: Scale and zero_point tensors must use buffer storage with a width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping. Input tensors must be an integer type (kByte, kChar, kInt).

# Block-wise Dequantization Implementation

Block-wise dequantization enables fine-grained reconstruction by dividing tensors into blocks and applying separate dequantization parameters to each block.
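To make the block-wise scheme concrete, here is a minimal CPU reference sketch of block-wise affine dequantization over a contiguous NCHW tensor. It is illustrative only: the function name, block layout, and parameter ordering are assumptions for this sketch, not the Vulkan backend's exact conventions.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Reference (CPU) block-wise affine dequantization over a contiguous NCHW tensor.
// Illustrative only: block_id layout and parameter ordering are assumptions here.
std::vector<float> dequantize_affine_ref(
    const std::vector<int8_t>& qdata,
    const std::array<int64_t, 4>& sizes,      // {N, C, H, W}
    const std::array<int64_t, 4>& block_size, // {bN, bC, bH, bW}, each divides sizes[i]
    const std::vector<float>& scale,          // one entry per block
    const std::vector<int32_t>& zero_point) { // one entry per block
  std::array<int64_t, 4> num_blocks;
  for (int i = 0; i < 4; ++i) {
    num_blocks[i] = sizes[i] / block_size[i];
  }
  std::vector<float> out(qdata.size());
  int64_t idx = 0;
  for (int64_t n = 0; n < sizes[0]; ++n) {
    for (int64_t c = 0; c < sizes[1]; ++c) {
      for (int64_t h = 0; h < sizes[2]; ++h) {
        for (int64_t w = 0; w < sizes[3]; ++w, ++idx) {
          // Block coordinate of this element, then a linear block id
          // (row-major over the {N, C, H, W} block grid).
          const int64_t bn = n / block_size[0];
          const int64_t bc = c / block_size[1];
          const int64_t bh = h / block_size[2];
          const int64_t bw = w / block_size[3];
          const int64_t block_id =
              ((bn * num_blocks[1] + bc) * num_blocks[2] + bh) * num_blocks[3] + bw;
          // Affine dequantization: value = (qvalue - zero_point) * scale
          out[idx] = (static_cast<int32_t>(qdata[idx]) - zero_point[block_id]) *
              scale[block_id];
        }
      }
    }
  }
  return out;
}
```

For example, a 6×9×4 tensor (treated as {1, 6, 9, 4}) with `block_size = {1, 3, 3, 2}` yields 12 blocks of 18 elements each, matching the example used in the shader documentation.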
The implementation uses the same key data structures computed in `Dequantize.cpp`:

* **`block_size_vec`**: WHCN-ordered block dimensions converted from the PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension, calculated as `tensor_size_whcn / block_size_vec`
* **`block_stride_vec`**: Pre-computed linear strides for the block grid, `{1, #W, #W*#H, #W*#H*#C}`, enabling efficient block ID calculation

The block coordinate calculation is `bcoord = tidx / blockSize`, where `tidx` is the tensor coordinate in WHCN layout. The linear block ID is then computed as `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.

# Shader Algorithm Overview

## Texture Storage Implementation (`dequantize_texture.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default, with special handling for batch-dimension dequantization (the Z dimension is set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)

**Block-wise Mode Algorithm**: The shader processes 3D texture positions, where each position represents a texel containing 4 width-packed integer components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width packing. For each of the 4 components in the texel, it computes the actual tensor coordinate `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))`, where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` is the total channel extent. The block coordinate is calculated with integer division, `bcoord = tidx / blockSize`, and the linear block ID uses the pre-computed strides: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`. Each integer component is dequantized with its block's parameters, `value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id])`, where `dequantize_val()` applies `(qvalue - zero_point) * scale`. The reconstructed floating-point values are written to the output texel, with proper type handling for double-precision outputs.
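As a companion to the indexing description above, here is a small scalar C++ sketch of the WHCN block-grid strides and the texel-component-to-coordinate mapping. The helper names are illustrative, not the actual ExecuTorch utilities.

```cpp
#include <array>
#include <cstdint>

using ivec4 = std::array<int32_t, 4>; // {W, H, C, N} (WHCN order, as in the shaders)

// Pre-compute linear strides for the block grid: {1, #W, #W*#H, #W*#H*#C}.
ivec4 make_block_stride(const ivec4& num_blocks) {
  return {1,
          num_blocks[0],
          num_blocks[0] * num_blocks[1],
          num_blocks[0] * num_blocks[1] * num_blocks[2]};
}

// Linear block id for a WHCN tensor coordinate, matching
// block_id = dot(tidx / blockSize, blockStride) in the shaders.
int32_t block_id_for(const ivec4& tidx, const ivec4& block_size, const ivec4& block_stride) {
  int32_t id = 0;
  for (int i = 0; i < 4; ++i) {
    id += (tidx[i] / block_size[i]) * block_stride[i];
  }
  return id;
}

// Texture path: texel position (x, y, z) holds 4 width-packed components, and
// pos.z folds channel and batch: z = batch * C_total + channel.
ivec4 texel_component_to_tidx(int x, int y, int z, int component, int C_total) {
  return {x * 4 + component, y, z % C_total, z / C_total};
}
```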
## Buffer Storage Implementation (`dequantize_buffer.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints

**Block-wise Mode Algorithm**: The shader processes linear buffer indices, using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates with `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)`, which handles the buffer-to-tensor index mapping with the proper stride calculations. For each element, it computes the block coordinate directly as `bcoord = out_tidx / blockSize`, where `out_tidx` is the 4D tensor coordinate in WHCN layout. The linear block ID uses the same pre-computed stride approach: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`. The quantized integer value is loaded with the corresponding input buffer index, `qvalue = t_in[in_bufi]`, where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`. Dequantization then applies the block-specific parameters, `value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id])`, to reconstruct the original floating-point value. A simplified CPU sketch of this indexing appears after the notes below.

**Future Improvements**: Dynamic workgroup sizing based on block dimensions.

Differential Revision: [D78435552](https://our.internmc.facebook.com/intern/diff/D78435552/)

[ghstack-poisoned]
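For reference, here is a simplified CPU sketch of the buffer-path indexing described in the Buffer Storage Implementation section, assuming a contiguous width-packed WHCN layout. The real shader uses `bufi_to_tidx()`/`tidx_to_bufi()` with explicit strides and dim order; the helper names below are illustrative.

```cpp
#include <array>
#include <cstdint>

using ivec4 = std::array<int32_t, 4>; // {W, H, C, N}

// Convert a linear buffer index to a WHCN tensor coordinate, assuming a
// contiguous width-packed layout with strides {1, W, W*H, W*H*C}.
ivec4 bufi_to_tidx_contiguous(int32_t bufi, const ivec4& sizes) {
  ivec4 tidx;
  tidx[0] = bufi % sizes[0];
  tidx[1] = (bufi / sizes[0]) % sizes[1];
  tidx[2] = (bufi / (sizes[0] * sizes[1])) % sizes[2];
  tidx[3] = bufi / (sizes[0] * sizes[1] * sizes[2]);
  return tidx;
}

// Block-wise dequantization of one element, mirroring the buffer shader:
//   bcoord = tidx / blockSize; block_id = dot(bcoord, blockStride);
//   value = (qvalue - zero_point[block_id]) * scale[block_id]
float dequantize_element(
    int32_t qvalue,
    const ivec4& tidx,
    const ivec4& block_size,
    const ivec4& block_stride,
    const float* scale,
    const int32_t* zero_point) {
  int32_t block_id = 0;
  for (int i = 0; i < 4; ++i) {
    block_id += (tidx[i] / block_size[i]) * block_stride[i];
  }
  return (qvalue - zero_point[block_id]) * scale[block_id];
}
```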
1 parent 00559b6 commit d172021

File tree

6 files changed: +736 -164 lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 89 additions & 63 deletions
```diff
@@ -53,6 +53,17 @@ $if MODE == "per_channel":
     int quant_min;
     int quant_max;
   };
+$if MODE == "block_wise":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
+  layout(push_constant) uniform restrict Block {
+    ivec4 blockSize;   // bW, bH, bC, bN
+    ivec4 numBlocks;   // tW/bW, tH/bH, tC/bC, tN/bN
+    ivec4 blockStride; // pre-computed linear strides for the block grid
+    int quant_min;
+    int quant_max;
+  };
 
 ${layout_declare_ubo(B, "int", "out_numel")}
 ${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
@@ -71,68 +82,60 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
 const lowp ivec4 in_dim_order = unhash_dim_order(in_layout);
 
 /*
- * DEQUANTIZATION SHADER (BUFFER STORAGE)
- *
- * This shader converts n-bit integer tensor values back to floating-point representations
- * using pre-computed quantization parameters (scale and zero_point). The dequantization
- * reconstructs the original floating-point values from their discrete integer representations
- * with minimal precision loss.
- *
- * ALGORITHM:
- * 1. Load quantized integer value from buffer
- * 2. Apply dequantization formula: value = (qvalue - zero_point) * scale
- * 3. Store reconstructed floating-point value to output buffer
- *
- * WORKGROUP CONFIGURATION:
- * - Per-Tensor Mode:
- *   - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
- *   - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
- * - Per-Token Mode:
- *   - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
- *   - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
- *
- * SUPPORTED CONFIGURATIONS:
- * - Buffer Storage: Uses linear buffer indexing with stride-based tensor access
- * - Per-Tensor: Supports any tensor layout through stride calculations and dimension ordering
- * - Per-Token: Supports only width packed tensors (packed_dim = 0) and standard axis mapping
- * - Scale/zero_point tensors: Must use buffer storage with width packing (packed_dim = 0)
- *
- * DEQUANTIZATION FORMULA VISUALIZATION:
- * For integer range [quant_min, quant_max] mapped back to [min_val, max_val]:
- *
- * Integer Domain:           Floating Point Domain:
- * quant_min ──────────────► min_val
- *    │                         │
- *    │    scale = (max_val - min_val) / (quant_max - quant_min)
- *    │    zero_point = quant_min - round(min_val / scale)
- *    │                         │
- * quant_max ──────────────► max_val
- *
- * Dequantization Process:
- *   Input: -103 (int8)
- *   Step 1: qvalue - zero_point = -103 - (-128) = 25
- *   Step 2: result * scale = 25 * 0.1 = 2.5
- *   Output: 2.5 (float)
- *
- * PER-TENSOR DEQUANTIZATION:
- * - Single scale and zero_point values for entire tensor
- * - All elements use same dequantization parameters
- * - Parameters passed as push constants for efficiency
- * - Formula: value = (qvalue - zero_point) * scale
- *
- * PER-TOKEN DEQUANTIZATION:
- * - Separate scale and zero_point for each token
- * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements)
- * - Parameters stored in buffer arrays indexed by token_id
- * - Each thread calculates its token_id from tensor coordinates
- * - Formula: value = (qvalue - zero_point[token_id]) * scale[token_id]
- *
- * Token ID calculation for element at tensor index (w, z, y, x):
- * - 4D tensor: token_id = w * (sizes.z * sizes.y) + z * sizes.y + y
- * - 3D tensor: token_id = z * sizes.y + y
- * - 2D tensor: token_id = y
- * - 1D tensor: token_id = 0
- */
+ Dequantization Shader (Buffer Storage)
+ This shader converts n-bit integer tensor values back to floating-point representations
+ using pre-computed quantization parameters (scale and zero_point). The dequantization
+ reconstructs the original floating-point values from their discrete integer representations
+ with minimal precision loss.
+
+ Important Considerations:
+ (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension)
+ (+) The axis map layout is assumed to be a standard layout for scales and zero_points
+ (++) The scale and zero_point tensors must be implemented as buffers
+
+ Workgroup Configuration:
+ - dequantize_per_tensor
+   This mode reverses the uniform quantization applied across the entire tensor by using the
+   single scale and zero_point values to convert quantized integer values back to their original
+   floating-point representation.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - dequantize_per_token
+   This mode reverses the quantization applied individually to each token (or element) in the
+   input by using separate scale and zero_point values for each token. For a tensor of shape
+   [B, S, H], it applies the inverse transformation token-wise across the B*S tokens, converting
+   quantized values back to their original floating-point representation for each group of H
+   elements independently.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - dequantize_per_channel
+   This mode reverses the quantization applied separately to each channel of the input tensor
+   by using distinct scale and zero_point values for each channel. For a tensor of shape
+   [B, C, H, W] with axis = 1, it applies the inverse transformation channel-wise across the C
+   channels, converting quantized values back to their original floating-point representation
+   independently for each channel.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - dequantize_block_wise
+   This mode reverses the block-wise quantization applied to groups of elements by using separate
+   scale and zero_point values for each block. Equivalent to dequantize_affine, it applies the
+   inverse affine transformation per block to convert quantized values back to their original
+   floating-point representation. For example, if the tensor shape is [6, 9, 4] and
+   blockSize = [3, 3, 2], the tensor is divided into 12 blocks, each containing 18 elements,
+   and dequantization is performed independently on each block.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ Dequantization Formula:
+   value = (qvalue - zero_point) * scale
+ */
 
 #ifdef per_tensor
 
@@ -187,7 +190,7 @@ void dequantize_per_token() {
   t_out[out_bufi] = value;
 }
 
-#else // per_channel
+#elif defined(per_channel)
 
 void dequantize_per_channel() {
   const int out_bufi = int(gl_GlobalInvocationID.x);
@@ -226,6 +229,29 @@ void dequantize_per_channel() {
   t_out[out_bufi] = value;
 }
 
+#else // block_wise
+
+void dequantize_block_wise() {
+  const int out_bufi = int(gl_GlobalInvocationID.x);
+
+  if (out_bufi >= out_numel) {
+    return;
+  }
+
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
+  const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
+
+  IN_T qvalue = t_in[in_bufi];
+
+  const ivec4 bcoord = out_tidx / blockSize;
+
+  const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;
+
+  const OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]);
+
+  t_out[out_bufi] = value;
+}
+
 #endif
 
 void main() {
```

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,3 +19,5 @@ dequantize_buffer:
     MODE: per_token
   - NAME: dequantize_per_channel_buffer
     MODE: per_channel
+  - NAME: dequantize_block_wise_buffer
+    MODE: block_wise
```

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 45 additions & 1 deletion
```diff
@@ -56,6 +56,17 @@ $if MODE == "per_channel":
     int quant_min;
     int quant_max;
   };
+$if MODE == "block_wise":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
+  layout(push_constant) uniform restrict Block {
+    ivec4 blockSize;   // bW, bH, bC, bN
+    ivec4 numBlocks;   // tW/bW, tH/bH, tC/bC, tN/bN
+    ivec4 blockStride; // pre-computed linear strides for the block grid
+    int quant_min;
+    int quant_max;
+  };
 
 ${layout_declare_ubo(B, "ivec3", "t_in_limits")}
 ${layout_declare_ubo(B, "ivec3", "t_out_limits")}
@@ -201,7 +212,7 @@ void dequantize_per_token() {
   write_texel(t_out, pos, outtex);
 }
 
-#else // per_channel
+#elif defined(per_channel)
 
 void dequantize_per_channel() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -292,6 +303,39 @@ void dequantize_per_channel() {
   write_texel(t_out, pos, outtex);
 }
 
+#else // block_wise
+
+void dequantize_block_wise() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, t_in_limits)))
+    return;
+
+  IVEC4_T intex = load_texel(t_in, pos);
+  FVEC4_T outtex;
+
+  ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0);
+  int foldedZ = pos.z;
+
+  int C_total = numBlocks.z * blockSize.z;
+
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total));
+
+    ivec4 bcoord = tidx / blockSize;
+    int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;
+
+    IN_T qvalue = IN_T(intex[i]);
+    OUT_T value = dequantize_val(qvalue, t_scale[block_id], t_zero_point[block_id]);
+    $if OUT_DTYPE == "double":
+      outtex[i] = float(value);
+    $else:
+      outtex[i] = value;
+  }
+
+  write_texel(t_out, pos, outtex);
+}
+
 #endif
 
 void main() {
```

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,3 +19,5 @@ dequantize_texture:
     MODE: per_token
   - NAME: dequantize_per_channel_texture3d
     MODE: per_channel
+  - NAME: dequantize_block_wise_texture3d
+    MODE: block_wise
```
