
Commit 7667af4

Author: morelos (committed)
[ET-VK][Ops] torchao.quantize_affine vulkan impl and shader and cleanup
Pull Request resolved: #12575

# Changes

* Implement the `torchao.quantize_affine` operator in the Vulkan backend with texture and buffer storage support
* Add a block-wise quantization mode to the `quantize_texture.glsl` and `quantize_buffer.glsl` shaders for configurable tensor-block quantization
* Introduce a test suite in `affine_test.cpp` with multi-dimensional tensor validation and a reference implementation
* Extend the quantization infrastructure in `Quantize.cpp` to handle affine transformations with configurable block sizes and quantization parameters

BE: Improved the documentation in the shader logic to be more detailed and clear.

NOTE: I delegated `quantize_affine` and future affine operators to a new custom test file, `affine_test.cpp`, since the existing quantization testing framework was getting large, and it makes more sense to separate the torchao and quantized_decomposed namespaces. I believe the _decomposed namespace is being phased out in favor of this affine operator, so deprecation will be easier in the future.

# Motivation

The existing Vulkan quantization infrastructure lacked support for the `torchao.quantize_affine` operator, which is essential for efficient dynamic quantization.
The `quantize_affine` operator provides flexible block-wise quantization that allows different scale and zero-point values for tensor blocks, enabling:

* **Block-wise Quantization**: Applies quantization parameters to configurable tensor blocks rather than entire tensors, improving quantization accuracy for heterogeneous data distributions
* **Affine Transformation**: Uses the formula `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)` for precise floating-point to integer mapping

# Operator Description

The `quantize_affine` operator converts floating-point tensor values to n-bit integer representations using pre-computed quantization parameters (scale and zero_point) applied to configurable tensor blocks. Block-wise quantization divides tensors into blocks and applies separate quantization parameters to each block, allowing fine-grained control over quantization precision.

The quantization formula is: `qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)`

**Storage Requirements**: Scale and zero_point tensors must use buffer storage with width-packed layout. Input/output tensors support both buffer and texture storage with standard axis mapping.

# Block-wise Quantization Implementation

Block-wise quantization enables fine-grained quantization by dividing tensors into blocks and applying separate quantization parameters to each block.
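As an illustration, the affine formula above can be sketched in plain Python. This is a hypothetical analogue of the shader's `quantize_val` helper, not the Vulkan implementation; the function name and defaults are assumptions for the example.

```python
def quantize_val(value: float, scale: float, zero_point: int,
                 quant_min: int = -128, quant_max: int = 127) -> int:
    """qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max).

    Note: Python's round() uses banker's rounding; half-way cases may differ
    from the shader's rounding mode, so this sketch avoids them.
    """
    q = round(value / scale) + zero_point
    return max(quant_min, min(quant_max, q))

print(quantize_val(2.5, 0.1, -128))   # 25 + (-128) = -103
print(quantize_val(-1.0, 0.1, -128))  # -10 + (-128) = -138 -> clamped to -128
```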
The implementation uses several key data structures computed in `Quantize.cpp`:

* **`block_size_vec`**: WHCN-ordered block dimensions converted from PyTorch NCHW layout (e.g., [3,3,2,1] for 3×3×2×1 blocks)
* **`tensor_size_whcn`**: Input tensor dimensions converted to WHCN layout using `utils::make_whcn_ivec4()`
* **`num_blocks_vec`**: Number of blocks per dimension, calculated as `tensor_size_whcn / block_size_vec`
* **`block_stride_vec`**: Pre-computed linear strides for block grid indexing, `{1, #W, #W*#H, #W*#H*#C}`, to enable efficient block ID calculation

The block coordinate calculation uses `bcoord = tidx / blockSize`, where `tidx` is the tensor coordinate in WHCN layout. The linear block ID is then computed as `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`.

# Shader Algorithm Overview

## Texture Storage Implementation (`quantize_texture.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on texture dimensions
- **Local WG Size**: Default, with special handling for batch-dimension quantization (the Z dimension is set to 1 for proper workgroup dispatching when `global_workgroup_size[2] > 1`)

**Block-wise Mode Algorithm**: The shader processes 3D texture positions where each position represents a texel containing 4 width-packed components. For each texel at position `pos`, it calculates a base tensor index `base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0)` to account for width-packing. For each of the 4 components in the texel, it computes the actual tensor coordinate `tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total))`, where `foldedZ = pos.z` handles batch-channel folding in 4D tensors and `C_total = numBlocks.z * blockSize.z` represents the total channel dimension.
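The block-grid bookkeeping above can be sketched in plain Python. This is an illustrative model of `num_blocks_vec`, `block_stride_vec`, and the `block_id` computation in WHCN order; the function names are hypothetical, not the C++ API.

```python
def make_block_grid(tensor_size_whcn, block_size_whcn):
    # num_blocks = tensor_size_whcn / block_size_vec (elementwise)
    num_blocks = [t // b for t, b in zip(tensor_size_whcn, block_size_whcn)]
    # blockStride = {1, #W, #W*#H, #W*#H*#C}
    block_stride = [1,
                    num_blocks[0],
                    num_blocks[0] * num_blocks[1],
                    num_blocks[0] * num_blocks[1] * num_blocks[2]]
    return num_blocks, block_stride

def block_id(tidx_whcn, block_size_whcn, block_stride):
    # bcoord = tidx / blockSize, then dot with the pre-computed strides
    bcoord = [t // b for t, b in zip(tidx_whcn, block_size_whcn)]
    return sum(c * s for c, s in zip(bcoord, block_stride))

# Tensor [6, 9, 4] (W, H, C) with blockSize [3, 3, 2]: 2*3*2 = 12 blocks
num_blocks, stride = make_block_grid([6, 9, 4, 1], [3, 3, 2, 1])
print(num_blocks)                                    # [2, 3, 2, 1]
print(block_id([5, 8, 3, 0], [3, 3, 2, 1], stride))  # last element -> block 11
```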
The block coordinate is calculated using integer division, `bcoord = tidx / blockSize`, and the linear block ID uses the pre-computed strides: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`. Each component is quantized using its corresponding block's parameters, `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])`, and written to the output texel.

## Buffer Storage Implementation (`quantize_buffer.glsl`)

**Workgroup Configuration**:
- **Global WG Size**: Default sizing based on buffer element count
- **Local WG Size**: Default sizing without special constraints

**Block-wise Mode Algorithm**: The shader processes linear buffer indices, using `gl_GlobalInvocationID.x` as the output buffer index. It converts this to tensor coordinates with `bufi_to_tidx(out_bufi, t_out_strides, out_dim_order)`, which handles the buffer-to-tensor index mapping with proper stride calculations. For each element, it computes the block coordinate directly as `bcoord = out_tidx / blockSize`, where `out_tidx` is the 4D tensor coordinate in WHCN layout. The linear block ID calculation uses the same pre-computed stride approach: `block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w`. The element value is loaded from the corresponding input buffer index, `value = t_in[in_bufi]` where `in_bufi = tidx_to_bufi(out_tidx, t_in_strides)`, and quantization applies the block-specific parameters: `qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id])`.

**Future Improvements**: Dynamic workgroup sizing based on block dimensions; there is likely a better approach than the current default sizing.

ghstack-source-id: 299473617
@exported-using-ghexport

Differential Revision: [D78302195](https://our.internmc.facebook.com/intern/diff/D78302195/)
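The width-packed texel-to-tensor coordinate mapping used by the texture shader above can be sketched as follows. This is a hedged model of the index math only (assuming 4 components per texel and `pos.z` folding channel and batch); the function name is illustrative, not part of the shader.

```python
def texel_to_tidx(pos, c_total):
    """Yield the WHCN coordinate of each of the 4 width-packed components
    of the texel at 3D position pos = (x, y, z)."""
    base_w = pos[0] * 4    # width-packing: 4 elements per texel along W
    folded_z = pos[2]      # z folds (channel, batch) for 4D tensors
    for i in range(4):
        yield (base_w + i,             # W
               pos[1],                 # H
               folded_z % c_total,     # C
               folded_z // c_total)    # N

# With C_total = 4 channels, texel z = 5 lands in batch 1, channel 1:
print(list(texel_to_tidx((0, 2, 5), 4)))
# [(0, 2, 1, 1), (1, 2, 1, 1), (2, 2, 1, 1), (3, 2, 1, 1)]
```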
1 parent 272393c commit 7667af4

File tree

7 files changed: +798 −149 lines


backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl

Lines changed: 83 additions & 59 deletions

@@ -53,6 +53,17 @@ $if MODE == "per_channel":
     int quant_min;
     int quant_max;
   };
+$if MODE == "block_wise":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
+  layout(push_constant) uniform restrict Block {
+    ivec4 blockSize;   // bW, bH, bC, bN
+    ivec4 numBlocks;   // tW/bW, tH/bH, tC/bC, tN/bN
+    ivec4 blockStride; // pre-computed linear strides for the block grid
+    int quant_min;
+    int quant_max;
+  };
 
 ${layout_declare_ubo(B, "int", "out_numel")}
 ${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
@@ -71,64 +82,54 @@ const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
 const lowp ivec4 in_dim_order = unhash_dim_order(in_layout);
 
 /*
- * QUANTIZATION SHADER (BUFFER STORAGE)
- *
- * This shader converts floating-point tensor values to n-bit integer representations
- * using pre-computed quantization parameters (scale and zero_point). The quantization
- * maps floating-point values to a discrete integer range while preserving the
- * original data distribution as much as possible.
- *
- * ALGORITHM:
- * 1. Load floating-point input value from buffer
- * 2. Apply quantization formula: qvalue = round(value / scale) + zero_point
- * 3. Clamp result to [quant_min, quant_max] range
- * 4. Store quantized integer value to output buffer
- *
- * WORKGROUP CONFIGURATION:
- * - Per-Tensor Mode:
- *   - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
- *   - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
- * - Per-Token Mode:
- *   - Global WG Size: {num_elements, 1, 1} (one thread per tensor element)
- *   - Local WG Size: Default (typically {64, 1, 1} or based on global WG size)
- *
- * SUPPORTED CONFIGURATIONS:
- * - Per-Tensor Config: Uses linear buffer indexing with stride-based tensor access
- *   and supports any tensor layout through stride calculations and dimension ordering
- * - Per-Token Config: Assumes width-packed layout (packed_dim = 0)
- *   since that is how token index is calculated
- *
- * QUANTIZATION FORMULA VISUALIZATION:
- * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]:
- *
- * Floating Point Domain:    Integer Domain:
- * min_val ────────────────► quant_min
- * │                         │
- * │   scale = (max_val - min_val) / (quant_max - quant_min)
- * │   zero_point = quant_min - round(min_val / scale)
- * │                         │
- * max_val ────────────────► quant_max
- *
- * Quantization Process:
- * Input: 2.5 (float)
- * Step 1: value / scale = 2.5 / 0.1 = 25.0
- * Step 2: round(25.0) + zero_point = 25 + (-128) = -103
- * Step 3: clamp(-103, -128, 127) = -103
- * Output: -103 (int8)
- *
- * PER-TENSOR QUANTIZATION:
- * - Single scale and zero_point values for entire tensor
- * - All elements use same quantization parameters
- * - Parameters passed as push constants for efficiency
- * - Formula: qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max)
- *
- * PER-TOKEN QUANTIZATION:
- * - Separate scale and zero_point for each token
- * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements)
- * - Parameters stored in buffer arrays indexed by token_id
- * - Each thread calculates its token_id from tensor coordinates
- * - Formula: qvalue = clamp(round(value / scale[token_id]) + zero_point[token_id], quant_min, quant_max)
- */
+ Quantization Shader (Buffer Storage)
+ This shader converts floating-point tensor values to n-bit integer representations
+ using pre-computed quantization parameters (scale and zero_point). The quantization
+ maps floating-point values to a discrete integer range while preserving the original
+ data distribution as much as possible.
+
+ Important Considerations:
+ (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension)
+ (+) The axis map layout is assumed to be a standard layout for scales and zero_points
+ (++) The scale and zero_point tensors must be implemented as buffers
+
+ Workgroup Configuration:
+ - quantize_per_tensor
+   This mode applies uniform quantization across the entire tensor using a single scale
+   and zero_point value.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - quantize_per_token
+   This mode applies quantization individually to each token (or element) in the input,
+   using separate scale and zero_point values for each token. For instance if we have
+   a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - quantize_per_channel
+   This mode applies quantization separately to each channel of the input tensor, using
+   distinct scale and zero_point values for each channel. For example, if the tensor shape
+   is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing
+   each channel to be quantized independently.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - quantize_block_wise
+   This mode applies quantization in blocks or groups of elements, allowing different scale
+   and zero_point values for each block. It is equivalent to quantize_affine, where quantization
+   parameters are affine transformations applied per block. For example, if the tensor shape
+   is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ Quantization Formula:
+ qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max).
+ */
@@ -183,7 +184,7 @@ void quantize_per_token() {
   t_out[out_bufi] = qvalue;
 }
 
-#else // per_channel
+#elif defined(per_channel)
 
 void quantize_per_channel() {
   const int out_bufi = int(gl_GlobalInvocationID.x);
@@ -222,6 +223,29 @@ void quantize_per_channel() {
   t_out[out_bufi] = qvalue;
 }
 
+#else // block_wise
+
+void quantize_block_wise() {
+  const int out_bufi = int(gl_GlobalInvocationID.x);
+
+  if (out_bufi >= out_numel) {
+    return;
+  }
+
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
+  const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
+
+  IN_T value = t_in[in_bufi];
+
+  const ivec4 bcoord = out_tidx / blockSize;
+
+  const int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;
+
+  const OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]);
+
+  t_out[out_bufi] = qvalue;
+}
+
 #endif
 
 void main() {

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 2 additions & 0 deletions

@@ -19,3 +19,5 @@ quantize_buffer:
     MODE: per_token
   - NAME: quantize_per_channel_buffer
     MODE: per_channel
+  - NAME: quantize_block_wise_buffer
+    MODE: block_wise

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl

Lines changed: 94 additions & 63 deletions

@@ -58,6 +58,17 @@ $if MODE == "per_channel":
     int quant_min;
     int quant_max;
   };
+$if MODE == "block_wise":
+  ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+  ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
+  layout(push_constant) uniform restrict BlockPC {
+    ivec4 blockSize;   // WHCN
+    ivec4 numBlocks;   // (#W,#H,#C,#N)
+    ivec4 blockStride; // {1, #W, #W * #H, #W * #H * #C}
+    int quant_min;
+    int quant_max;
+  };
 
 ${layout_declare_ubo(B, "ivec3", "t_in_limits")}
 ${layout_declare_ubo(B, "ivec3", "t_out_limits")}
@@ -70,68 +81,58 @@ ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 /*
- * QUANTIZATION SHADER (TEXTURE STORAGE)
- *
- * This shader converts floating-point tensor values to n-bit integer representations
- * using pre-computed quantization parameters (scale and zero_point). The quantization
- * maps floating-point values to a discrete integer range while preserving the
- * original data distribution as much as possible.
- *
- * ALGORITHM:
- * 1. Load floating-point texel (4 values) from 3D texture
- * 2. Apply quantization formula to each component: qvalue = round(value / scale) + zero_point
- * 3. Clamp each result to [quant_min, quant_max] range
- * 4. Store quantized integer texel to output texture
- *
- * WORKGROUP CONFIGURATION:
- * - Per-Tensor Mode:
- *   - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing
- *   - Local WG Size: Default (typically {8, 8, 1} or based on global WG size)
- * - Per-Token Mode:
- *   - Global WG Size: {W, H, C/4} for input size (W, H, C) with width-packing
- *   - Local WG Size: Default (typically {8, 8, 1} or based on global WG size)
- *
- * SUPPORTED CONFIGURATIONS:
- * - Texture Storage: Uses 3D texture indexing with texel-based processing
- * - Assumes width-packed layout (packed_dim = 0) in current implementation
- * - Handles texel padding for non-multiple-of-4 tensor dimensions
- * - For per-token mode: scale/zero_point tensors must use buffer storage
- *
- * QUANTIZATION FORMULA VISUALIZATION:
- * For input range [min_val, max_val] mapped to integer range [quant_min, quant_max]:
- *
- * Floating Point Domain:    Integer Domain:
- * min_val ────────────────► quant_min
- * │                         │
- * │   scale = (max_val - min_val) / (quant_max - quant_min)
- * │   zero_point = quant_min - round(min_val / scale)
- * │                         │
- * max_val ────────────────► quant_max
- *
- * Texel Quantization Process:
- * Input Texel: [2.5, -1.0, 0.5, 3.2] (float4)
- * Per-component quantization with scale=0.1, zero_point=-128:
- * Component 0: round(2.5 / 0.1) + (-128) = 25 + (-128) = -103
- * Component 1: round(-1.0 / 0.1) + (-128) = -10 + (-128) = -138 → clamp to -128
- * Component 2: round(0.5 / 0.1) + (-128) = 5 + (-128) = -123
- * Component 3: round(3.2 / 0.1) + (-128) = 32 + (-128) = -96
- * Output Texel: [-103, -128, -123, -96] (int4)
- *
- * PER-TENSOR QUANTIZATION:
- * - Single scale and zero_point values for entire tensor
- * - All texel components use same quantization parameters
- * - Parameters passed as push constants for efficiency
- * - Each thread processes one texel (4 elements) independently
- * - Formula: qvalue[i] = clamp(round(value[i] / scale) + zero_point, quant_min, quant_max)
- *
- * PER-TOKEN QUANTIZATION:
- * - Separate scale and zero_point for each token
- * - Token = all elements except last dimension (e.g., for [B,S,H]: B*S tokens of H elements)
- * - Parameters stored in buffer arrays indexed by token_id
- * - Each thread calculates token_id from its 3D texture position
- * - Scale/zero_point buffers accessed directly (not as textures)
- * - Formula: qvalue[i] = clamp(round(value[i] / scale[token_id]) + zero_point[token_id], quant_min, quant_max)
- */
+ Quantization Shader (Texture Storage)
+ This shader converts floating-point tensor values to n-bit integer representations
+ using pre-computed quantization parameters (scale and zero_point). The quantization
+ maps floating-point values to a discrete integer range while preserving the original
+ data distribution as much as possible.
+
+ Important Considerations:
+ (+) All input tensors are assumed to be WIDTH_PACKED (i.e., contiguous in the last dimension)
+ (+) The axis map layout is assumed to be a standard layout for scales and zero_points
+ (++) The scale and zero_point tensors must be implemented as buffers
+
+ Workgroup Configuration:
+ - quantize_per_tensor
+   This mode applies uniform quantization across the entire tensor using a single scale
+   and zero_point value.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - quantize_per_token
+   This mode applies quantization individually to each token (or element) in the input,
+   using separate scale and zero_point values for each token. For instance if we have
+   a tensor of shape [B, S, H] then we have B*S tokens (and s+zp pairs) of H elements each.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: default
+
+ - quantize_per_channel
+   This mode applies quantization separately to each channel of the input tensor, using
+   distinct scale and zero_point values for each channel. For example, if the tensor shape
+   is [B, C, H, W] and axis = 1, quantization parameters are computed per channel C, allowing
+   each channel to be quantized independently.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: Default with special handling for batch dimension. When quantizing along
+   the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise,
+   uses standard workgroup size derived from global workgroup dimensions.
+
+ - quantize_block_wise
+   This mode applies quantization in blocks or groups of elements, allowing different scale
+   and zero_point values for each block. It is equivalent to quantize_affine, where quantization
+   parameters are affine transformations applied per block. For example, if the tensor shape
+   is [6, 9, 4] and blockSize = [3, 3, 2], then we have 12 blocks each with 18 elements.
+
+   (*) global_wg_size: default
+   (*) local_wg_size: Default with special handling for batch dimension. When quantizing along
+   the batch axis, Z dimension is set to 1 to ensure correct workgroup dispatching. Otherwise,
+   uses standard workgroup size derived from global workgroup dimensions.
+
+ Quantization Formula:
+ qvalue = clamp(round(value / scale) + zero_point, quant_min, quant_max).
+ */
@@ -192,7 +193,7 @@ void quantize_per_token() {
   write_texel(t_out, pos, outtex);
 }
 
-#else // per_channel
+#elif defined(per_channel)
 
 void quantize_per_channel() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -270,6 +271,36 @@ void quantize_per_channel() {
   write_texel(t_out, pos, outtex);
 }
 
+#else // block_wise
+
+void quantize_block_wise() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+
+  if (any(greaterThanEqual(pos, t_in_limits)))
+    return;
+
+  FVEC4_T intex = load_texel(t_in, pos);
+  IVEC4_T outtex;
+
+  ivec4 base_tidx = ivec4(pos.x * 4, pos.y, pos.z, 0);
+  int foldedZ = pos.z;
+
+  int C_total = numBlocks.z * blockSize.z;
+
+  [[unroll]] for (int i = 0; i < 4; ++i) {
+    ivec4 tidx = ivec4(base_tidx.x + i, base_tidx.y, (foldedZ % C_total), (foldedZ / C_total));
+
+    ivec4 bcoord = tidx / blockSize;
+    int block_id = bcoord.x * blockStride.x + bcoord.y * blockStride.y + bcoord.z * blockStride.z + bcoord.w * blockStride.w;
+
+    IN_T value = IN_T(intex[i]);
+    OUT_T qvalue = quantize_val(value, t_scale[block_id], t_zero_point[block_id]);
+    outtex[i] = qvalue;
+  }
+
+  write_texel(t_out, pos, outtex);
+}
+
 #endif
 
 void main() {

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 2 additions & 0 deletions

@@ -19,3 +19,5 @@ quantize_texture:
    MODE: per_token
  - NAME: quantize_per_channel_texture3d
    MODE: per_channel
+ - NAME: quantize_block_wise_texture3d
+   MODE: block_wise
