@@ -45,6 +45,16 @@ $if MODE == "per_token":
4545 int quant_min;
4646 int quant_max;
4747 };
48+ $if MODE == "per_channel":
49+ ${layout_declare_tensor(B, "r", "t_scale", "float ", "buffer ")}
50+ ${layout_declare_tensor(B, "r", "t_zero_point", "int ", "buffer ")}
51+
52+ layout (push_constant) uniform restrict Block {
53+ int axis;
54+ int num_channels;
55+ int quant_min;
56+ int quant_max;
57+ };
4858
4959${layout_declare_ubo(B, "ivec3 ", "t_in_limits")}
5060${layout_declare_ubo(B, "ivec3 ", "t_out_limits")}
@@ -147,7 +157,7 @@ void dequantize_per_tensor() {
147157 write_texel(t_out, pos, outtex);
148158}
149159
150- #else
160+ #elif defined(per_token)
151161
152162void dequantize_per_token() {
153163 const ivec3 pos = ivec3 (gl_GlobalInvocationID);
@@ -189,6 +199,97 @@ void dequantize_per_token() {
189199 write_texel(t_out, pos, outtex);
190200}
191201
202+ #else // per_channel
203+
204+ void dequantize_per_channel() {
205+ const ivec3 pos = ivec3 (gl_GlobalInvocationID);
206+
207+ if (any (greaterThanEqual (pos, t_in_limits))) {
208+ return ;
209+ }
210+
211+ IVEC4_T intex = load_texel(t_in, pos);
212+ FVEC4_T outtex;
213+
214+ // Calculate channel index based on the dequantization axis (already converted to WHCN)
215+ // The axis parameter is now in WHCN coordinate system:
216+ // axis 0 -> W dimension (pos.x)
217+ // axis 1 -> H dimension (pos.y)
218+ // axis 2 -> C dimension (pos.z)
219+ // axis 3 -> N dimension (batch folding in texture storage)
220+
221+ if (axis == 0 ) {
222+ // Width dimension - each texel component has different channel index
223+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
224+ IN_T qvalue = IN_T(intex[i]);
225+ int channel_idx = pos.x * 4 + i;
226+ channel_idx = min (channel_idx, num_channels - 1 );
227+
228+ float scale_val = t_scale[channel_idx];
229+ int zero_point_val = t_zero_point[channel_idx];
230+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
231+ $if OUT_DTYPE == "double ":
232+ outtex[i] = float (value);
233+ $else :
234+ outtex[i] = value;
235+ }
236+ } else if (axis == 1 ) {
237+ int channel_idx = pos.y;
238+ channel_idx = min (channel_idx, num_channels - 1 );
239+ float scale_val = t_scale[channel_idx];
240+ int zero_point_val = t_zero_point[channel_idx];
241+
242+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
243+ IN_T qvalue = IN_T(intex[i]);
244+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
245+ $if OUT_DTYPE == "double ":
246+ outtex[i] = float (value);
247+ $else :
248+ outtex[i] = value;
249+ }
250+ } else if (axis == 2 ) {
251+ // Channel dimension - for 4D tensors, need to account for batch-channel folding
252+ // The Z coordinate contains folded batch*channel information
253+ // We need to extract the actual channel index from the folded dimension
254+ int folded_idx = pos.z;
255+ int channel_idx = folded_idx % num_channels;
256+
257+ float scale_val = t_scale[channel_idx];
258+ int zero_point_val = t_zero_point[channel_idx];
259+
260+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
261+ IN_T qvalue = IN_T(intex[i]);
262+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
263+ $if OUT_DTYPE == "double ":
264+ outtex[i] = float (value);
265+ $else :
266+ outtex[i] = value;
267+ }
268+ } else if (axis == 3 ) {
269+ // Batch dimension - for 4D tensors, need to account for batch-channel folding
270+ // The Z coordinate contains folded batch*channel information
271+ // We need to extract the actual channel index from the folded dimension
272+ int folded_idx = pos.z;
273+ // In this case num_channels actually corresponds to the number of channels
274+ // the C dimension N(C)HW
275+ int channel_idx = folded_idx / num_channels;
276+
277+ float scale_val = t_scale[channel_idx];
278+ int zero_point_val = t_zero_point[channel_idx];
279+
280+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
281+ IN_T qvalue = IN_T(intex[i]);
282+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
283+ $if OUT_DTYPE == "double ":
284+ outtex[i] = float (value);
285+ $else :
286+ outtex[i] = value;
287+ }
288+ }
289+
290+ write_texel(t_out, pos, outtex);
291+ }
292+
192293#endif
193294
194295void main() {
0 commit comments