@@ -26,6 +26,8 @@ ${define_required_extensions(OUT_DTYPE)}
2626
2727layout (std430) buffer ;
2828
29+ #include "indexing_utils.h"
30+
2931${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3032${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3133
@@ -45,11 +47,23 @@ $if MODE == "per_token":
4547 int quant_min;
4648 int quant_max;
4749 };
50+ $if MODE == "per_channel":
51+ ${layout_declare_tensor(B, "r", "t_scale", "float ", "buffer ")}
52+ ${layout_declare_tensor(B, "r", "t_zero_point", "int ", "buffer ")}
53+
54+ layout (push_constant) uniform restrict Block {
55+ int axis;
56+ int num_channels;
57+ int quant_min;
58+ int quant_max;
59+ };
4860
4961${layout_declare_ubo(B, "ivec3 ", "t_in_limits")}
5062${layout_declare_ubo(B, "ivec3 ", "t_out_limits")}
5163
52- #include "indexing_utils.h"
64+ ${layout_declare_spec_const(C, "int ", "out_layout", "DEFAULT_LAYOUT")}
65+ ${layout_declare_spec_const(C, "int ", "in_layout", "DEFAULT_LAYOUT")}
66+
5367#include "quantize.glslh"
5468
5569layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
@@ -138,7 +152,7 @@ void quantize_per_tensor() {
138152 write_texel(t_out, pos, outtex);
139153}
140154
141- #else
155+ #elif defined(per_token)
142156
143157void quantize_per_token() {
144158 const ivec3 pos = ivec3 (gl_GlobalInvocationID);
@@ -177,6 +191,84 @@ void quantize_per_token() {
177191 write_texel(t_out, pos, outtex);
178192}
179193
194+ #else // per_channel
195+
/*
 * Per-channel quantization of a single texel.
 *
 * Loads the texel at this invocation's position from t_in, quantizes each of
 * its four components with quantize_val() using a scale / zero point read from
 * t_scale / t_zero_point, and writes the result to the same position in t_out.
 * Invocations whose position is at or beyond t_in_limits return without
 * writing anything.
 *
 * The index into t_scale / t_zero_point is derived from the `axis` push
 * constant as follows:
 *   axis 0: pos.x * 4 + i for component i, clamped to num_channels - 1
 *   axis 1: pos.y, clamped to num_channels - 1
 *   axis 2: pos.z % num_channels
 *   axis 3: pos.z / num_channels
 * Any other axis value produces an all-zero output texel.
 *
 * NOTE(review): the axis 0 mapping assumes a width-packed texture (four
 * consecutive W elements per texel) and that `axis` has already been converted
 * to WHCN order by the host, as stated by the original author - confirm
 * against the dispatch code.
 * NOTE(review): for axis 3, pos.z / num_channels only yields the batch index
 * if num_channels holds the size of the C dimension that is folded into
 * pos.z - confirm that the host passes that value for this axis.
 */
void quantize_per_channel() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, t_in_limits))) {
    return;
  }

  FVEC4_T intex = load_texel(t_in, pos);
  // Zero-initialize so an unsupported axis value never stores indeterminate
  // data into t_out.
  IVEC4_T outtex = IVEC4_T(0);

  if (axis == 0) {
    // Each component of the texel maps to its own quantization parameters.
    [[unroll]] for (int i = 0; i < 4; ++i) {
      IN_T value = IN_T(intex[i]);
      // The clamp keeps the lookup inside [0, num_channels - 1]; presumably
      // this covers padding components in the last texel along x.
      int channel_idx = min(pos.x * 4 + i, num_channels - 1);

      float scale_val = t_scale[channel_idx];
      int zero_point_val = t_zero_point[channel_idx];
      OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
      outtex[i] = qvalue;
    }
  } else if (axis >= 1 && axis <= 3) {
    // All four components of the texel share one set of parameters; only the
    // way the parameter index is derived depends on the axis.
    int channel_idx;
    if (axis == 1) {
      channel_idx = min(pos.y, num_channels - 1);
    } else if (axis == 2) {
      // pos.z is a folded index; the remainder selects the channel.
      channel_idx = pos.z % num_channels;
    } else {
      // pos.z is a folded index; the quotient selects the batch.
      channel_idx = pos.z / num_channels;
    }

    float scale_val = t_scale[channel_idx];
    int zero_point_val = t_zero_point[channel_idx];

    [[unroll]] for (int i = 0; i < 4; ++i) {
      IN_T value = IN_T(intex[i]);
      OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
      outtex[i] = qvalue;
    }
  }

  write_texel(t_out, pos, outtex);
}
271+
180272#endif
181273
182274void main() {
0 commit comments