Skip to content

Commit a6580b5

Browse files
authored
[ET-VK][Ops] quantize_per_channel shaders and impl (pytorch#12433)
# Context We need to enable the core logic for quantize_per_channel in the Vulkan backend. This commit implements the shaders themselves and the corresponding C++ header. TODO: add a fuller description of the operator. # Changes This extends the existing quantize buffer and texture shader templates, and their YAML variant lists, with a per_channel mode. Differential Revision: [D77746140](https://our.internmc.facebook.com/intern/diff/D77746140/)
1 parent aedb0fe commit a6580b5

File tree

6 files changed

+995
-7
lines changed

6 files changed

+995
-7
lines changed

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ $if MODE == "per_token":
4242
int quant_min;
4343
int quant_max;
4444
};
45+
// per_channel variant (buffer storage): bindings and push constants.
//   t_scale      - read-only float buffer, indexed by channel index in
//                  quantize_per_channel().
//   t_zero_point - read-only int buffer, indexed the same way.
//   axis         - selects which component of the tensor index is used as the
//                  channel index (0 -> x, 1 -> y, 2 -> z, 3 -> w).
//   num_channels - upper bound used to clamp the channel index
//                  (index <= num_channels - 1).
//   quant_min / quant_max - not referenced directly by the per_channel entry
//                  point; presumably consumed by quantize_val(), which is
//                  defined elsewhere - TODO confirm.
$if MODE == "per_channel":
46+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
47+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
48+
49+
layout(push_constant) uniform restrict Block {
50+
int axis;
51+
int num_channels;
52+
int quant_min;
53+
int quant_max;
54+
};
4555

4656
${layout_declare_ubo(B, "int", "out_numel")}
4757
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
@@ -137,7 +147,7 @@ void quantize_per_tensor() {
137147
t_out[out_bufi] = qvalue;
138148
}
139149

140-
#else
150+
#elif defined(per_token)
141151

142152
void quantize_per_token() {
143153
const int out_bufi = int(gl_GlobalInvocationID.x);
@@ -172,6 +182,45 @@ void quantize_per_token() {
172182
t_out[out_bufi] = qvalue;
173183
}
174184

185+
#else // per_channel
186+
187+
// Per-channel quantization for buffer-backed tensors.
// Each invocation handles the single output element whose buffer index is
// gl_GlobalInvocationID.x: it reads the matching input element, picks a
// scale / zero point by channel index, and stores the quantized value.
void quantize_per_channel() {
188+
const int out_bufi = int(gl_GlobalInvocationID.x);
189+
190+
// Invocations past the end of the output do nothing.
if (out_bufi >= out_numel) {
191+
return;
192+
}
193+
194+
// Map the output buffer index to a tensor index, then to the input buffer
// index, so input and output may use different strides.
const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
195+
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
196+
197+
IN_T value = t_in[in_bufi];
198+
199+
// Calculate channel index based on the quantization axis (already converted to WHCN)
200+
// The axis parameter is now in WHCN coordinate system:
201+
// axis 0 -> W dimension (tidx.x)
202+
// axis 1 -> H dimension (tidx.y)
203+
// axis 2 -> C dimension (tidx.z)
204+
// axis 3 -> N dimension (tidx.w)
// If axis is outside 0..3 no branch matches and channel_idx stays 0.
int channel_idx = 0;
206+
207+
if (axis == 0) {
208+
channel_idx = out_tidx.x;
209+
} else if (axis == 1) {
210+
channel_idx = out_tidx.y;
211+
} else if (axis == 2) {
212+
channel_idx = out_tidx.z;
213+
} else if (axis == 3) {
214+
channel_idx = out_tidx.w;
215+
}
216+
217+
// Upper clamp only. NOTE(review): with num_channels == 0 this yields -1;
// presumably the host guarantees num_channels >= 1 - confirm.
channel_idx = min(channel_idx, num_channels - 1);
218+
219+
// quantize_val() is defined outside this hunk; it receives the per-channel
// scale and zero point selected above.
OUT_T qvalue = quantize_val(value, t_scale[channel_idx], t_zero_point[channel_idx]);
220+
221+
t_out[out_bufi] = qvalue;
222+
}
223+
175224
#endif
176225

177226
void main() {

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ quantize_buffer:
1717
MODE: per_tensor
1818
- NAME: quantize_per_token_buffer
1919
MODE: per_token
20+
- NAME: quantize_per_channel_buffer
21+
MODE: per_channel

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ ${define_required_extensions(OUT_DTYPE)}
2626

2727
layout(std430) buffer;
2828

29+
#include "indexing_utils.h"
30+
2931
${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3032
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3133

@@ -45,11 +47,23 @@ $if MODE == "per_token":
4547
int quant_min;
4648
int quant_max;
4749
};
50+
// per_channel variant (texture storage): bindings and push constants.
//   t_scale      - read-only float buffer, indexed by channel (or batch)
//                  index in quantize_per_channel().
//   t_zero_point - read-only int buffer, indexed the same way.
//   axis         - selects how the index is derived from the texel position
//                  (0 -> pos.x * 4 + component, 1 -> pos.y, 2 and 3 -> pos.z).
//   num_channels - upper clamp for axis 0 / 1, and the modulus / divisor
//                  applied to pos.z for axis 2 / 3.
//   quant_min / quant_max - not referenced directly by the per_channel entry
//                  point; presumably consumed by quantize_val() from
//                  quantize.glslh - TODO confirm.
$if MODE == "per_channel":
51+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
52+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
53+
54+
layout(push_constant) uniform restrict Block {
55+
int axis;
56+
int num_channels;
57+
int quant_min;
58+
int quant_max;
59+
};
4860

4961
${layout_declare_ubo(B, "ivec3", "t_in_limits")}
5062
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
5163

52-
#include "indexing_utils.h"
64+
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
65+
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
66+
5367
#include "quantize.glslh"
5468

5569
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
@@ -138,7 +152,7 @@ void quantize_per_tensor() {
138152
write_texel(t_out, pos, outtex);
139153
}
140154

141-
#else
155+
#elif defined(per_token)
142156

143157
void quantize_per_token() {
144158
const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -177,6 +191,84 @@ void quantize_per_token() {
177191
write_texel(t_out, pos, outtex);
178192
}
179193

194+
#else // per_channel
195+
196+
// Per-channel quantization for texture-backed tensors.
// Each invocation handles one texel at pos = gl_GlobalInvocationID: it loads
// the input texel, quantizes its four components with a scale / zero point
// chosen according to `axis`, and writes the result texel at the same position.
// NOTE(review): the index math below does not consult out_layout / in_layout;
// the axis == 0 branch assumes a width-packed texel - confirm that other
// packings are rejected on the host side.
void quantize_per_channel() {
197+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
198+
199+
// Invocations outside the input texture extents do nothing.
if (any(greaterThanEqual(pos, t_in_limits))) {
200+
return;
201+
}
202+
203+
FVEC4_T intex = load_texel(t_in, pos);
204+
// NOTE(review): outtex is only assigned inside the axis branches below; if
// axis is outside 0..3 it is written out without ever being set.
IVEC4_T outtex;
205+
206+
// Calculate channel index based on the quantization axis (already converted to WHCN)
207+
// The axis parameter is now in WHCN coordinate system:
208+
// axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component)
209+
// axis 1 -> H dimension (pos.y)
210+
// axis 2 -> C dimension (pos.z % num_channels); for 4D tensors pos.z folds batch and channel
211+
// axis 3 -> N dimension (pos.z / num_channels); for 4D tensors pos.z folds batch and channel
212+
213+
if (axis == 0) {
214+
// Width dimension - each texel component has different channel index
215+
[[unroll]] for (int i = 0; i < 4; ++i) {
216+
IN_T value = IN_T(intex[i]);
217+
int channel_idx = pos.x * 4 + i;
218+
// Upper clamp keeps components past the last channel (padding in the
// final texel) from reading beyond num_channels - 1.
channel_idx = min(channel_idx, num_channels - 1);
219+
220+
float scale_val = t_scale[channel_idx];
221+
int zero_point_val = t_zero_point[channel_idx];
222+
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
223+
outtex[i] = qvalue;
224+
}
225+
} else if (axis == 1) {
226+
// Height dimension - all texel components use same channel index
227+
int channel_idx = pos.y;
228+
channel_idx = min(channel_idx, num_channels - 1);
229+
float scale_val = t_scale[channel_idx];
230+
int zero_point_val = t_zero_point[channel_idx];
231+
232+
[[unroll]] for (int i = 0; i < 4; ++i) {
233+
IN_T value = IN_T(intex[i]);
234+
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
235+
outtex[i] = qvalue;
236+
}
237+
} else if (axis == 2) {
238+
// Channel dimension - for 4D tensors, need to account for batch-channel folding
239+
// The Z coordinate contains folded batch*channel information
240+
// We need to extract the actual channel index from the folded dimension
241+
int folded_idx = pos.z;
242+
// Modulo (not a clamp) is used here, so the index wraps every
// num_channels slices along z.
int channel_idx = folded_idx % num_channels;
243+
244+
float scale_val = t_scale[channel_idx];
245+
int zero_point_val = t_zero_point[channel_idx];
246+
247+
[[unroll]] for (int i = 0; i < 4; ++i) {
248+
IN_T value = IN_T(intex[i]);
249+
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
250+
outtex[i] = qvalue;
251+
}
252+
} else if (axis == 3) {
253+
// Batch dimension - for 4D tensors, need to account for batch-channel folding
254+
// The Z coordinate contains folded batch*channel information
255+
// We need to extract the actual batch index from the folded dimension
256+
int folded_idx = pos.z;
257+
// NOTE(review): dividing by num_channels yields the batch index only if
// num_channels holds the per-batch channel count here (not the length of
// the batch axis), and the result is not clamped - confirm what the host
// passes for this axis.
int batch_idx = folded_idx / num_channels;
258+
259+
float scale_val = t_scale[batch_idx];
260+
int zero_point_val = t_zero_point[batch_idx];
261+
262+
[[unroll]] for (int i = 0; i < 4; ++i) {
263+
IN_T value = IN_T(intex[i]);
264+
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
265+
outtex[i] = qvalue;
266+
}
267+
}
268+
269+
write_texel(t_out, pos, outtex);
270+
}
271+
180272
#endif
181273

182274
void main() {

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ quantize_texture:
1717
MODE: per_tensor
1818
- NAME: quantize_per_token_texture3d
1919
MODE: per_token
20+
- NAME: quantize_per_channel_texture3d
21+
MODE: per_channel

0 commit comments

Comments
 (0)