Skip to content
51 changes: 50 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
// Quantization axis. dequantize_per_channel() compares it against 0..3 to pick
// which component of the output tensor index (x, y, z, w) is the channel index.
int axis;
// Bound for the channel index: dequantize_per_channel() clamps the index to
// num_channels - 1 before reading t_scale / t_zero_point.
int num_channels;
// NOTE(review): quant_min / quant_max are not referenced by the per-channel
// code visible here; presumably the quantized value range — confirm usage.
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
Expand Down Expand Up @@ -141,7 +151,7 @@ void dequantize_per_tensor() {
t_out[out_bufi] = value;
}

#else
#elif defined(per_token)

void dequantize_per_token() {
const int out_bufi = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -176,6 +186,45 @@ void dequantize_per_token() {
t_out[out_bufi] = value;
}

#else // per_channel

/*
 * Per-channel dequantization for buffer-backed tensors.
 *
 * Each invocation handles one output element: it maps the linear output
 * index to a tensor index, reads the matching quantized input element, and
 * dequantizes it with the scale / zero point of the channel that the element
 * occupies along `axis`.
 */
void dequantize_per_channel() {
  const int out_idx = int(gl_GlobalInvocationID.x);

  // Invocations past the end of the output have nothing to do.
  if (out_idx >= out_numel) {
    return;
  }

  const ivec4 tidx = bufi_to_tidx(out_idx, t_out_strides, out_dim_order);
  const int in_idx = tidx_to_bufi(tidx, t_in_strides);

  // `axis` is expressed in WHCN order (already converted before dispatch):
  //   0 -> W (tidx.x), 1 -> H (tidx.y), 2 -> C (tidx.z), 3 -> N (tidx.w).
  // Any other value falls back to channel 0.
  const int axis_coord =
      axis == 0 ? tidx.x :
      axis == 1 ? tidx.y :
      axis == 2 ? tidx.z :
      axis == 3 ? tidx.w : 0;

  // Keep the lookup inside the scale / zero point arrays.
  const int channel = min(axis_coord, num_channels - 1);

  const IN_T quantized = t_in[in_idx];
  const OUT_T dequantized =
      dequantize_val(quantized, t_scale[channel], t_zero_point[channel]);

  t_out[out_idx] = dequantized;
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ dequantize_buffer:
MODE: per_tensor
- NAME: dequantize_per_token_buffer
MODE: per_token
- NAME: dequantize_per_channel_buffer
MODE: per_channel
103 changes: 102 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
// Quantization axis. dequantize_per_channel() branches on values 0..3:
// 0 uses pos.x * 4 + i per texel component, 1 uses pos.y, 2 and 3 derive the
// index from pos.z (modulo / division by num_channels respectively).
int axis;
// Clamp bound (num_channels - 1) for axes 0 and 1; modulus for axis 2 and
// divisor for axis 3 when unfolding pos.z.
int num_channels;
// NOTE(review): quant_min / quant_max are not referenced by the per-channel
// code visible here; presumably the quantized value range — confirm usage.
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "ivec3", "t_in_limits")}
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
Expand Down Expand Up @@ -147,7 +157,7 @@ void dequantize_per_tensor() {
write_texel(t_out, pos, outtex);
}

#else
#elif defined(per_token)

void dequantize_per_token() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
Expand Down Expand Up @@ -189,6 +199,97 @@ void dequantize_per_token() {
write_texel(t_out, pos, outtex);
}

#else // per_channel

/*
 * Per-channel dequantization for texture-backed tensors.
 *
 * Each invocation handles one texel (four elements) at `pos`. The channel
 * index used to look up t_scale / t_zero_point is derived from `axis`, which
 * is expressed in WHCN order (already converted before dispatch):
 *   axis 0 -> W dimension (pos.x); every texel component is its own channel
 *   axis 1 -> H dimension (pos.y)
 *   axis 2 -> C dimension (pos.z folded with batch; pos.z % num_channels)
 *   axis 3 -> N dimension (pos.z folded with batch; pos.z / num_channels)
 * Any other axis value falls back to channel 0, matching the buffer variant,
 * so the texel written out is always fully defined.
 *
 * NOTE(review): the modulo / division by num_channels and the
 * `num_channels - 1` clamp assume num_channels > 0 — confirm the host
 * validates this before dispatch.
 */
void dequantize_per_channel() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);

  if (any(greaterThanEqual(pos, t_in_limits))) {
    return;
  }

  IVEC4_T intex = load_texel(t_in, pos);
  FVEC4_T outtex;

  if (axis == 0) {
    // Width dimension - each texel component has a different channel index.
    // NOTE(review): pos.x * 4 + i presumably assumes the four texel
    // components are consecutive along W (width-packed) — verify.
    [[unroll]] for (int i = 0; i < 4; ++i) {
      IN_T qvalue = IN_T(intex[i]);
      int channel_idx = pos.x * 4 + i;
      channel_idx = min(channel_idx, num_channels - 1);

      float scale_val = t_scale[channel_idx];
      int zero_point_val = t_zero_point[channel_idx];
      OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
      $if OUT_DTYPE == "double":
        outtex[i] = float(value);
      $else:
        outtex[i] = value;
    }
  } else {
    // For every other axis all four texel components share one channel, so
    // resolve the channel index once and run a single dequantize loop.
    int channel_idx = 0; // fallback for unsupported axis values

    if (axis == 1) {
      // Height dimension.
      channel_idx = min(pos.y, num_channels - 1);
    } else if (axis == 2) {
      // Channel dimension - the Z coordinate holds folded batch*channel
      // information; the remainder is the channel index.
      channel_idx = pos.z % num_channels;
    } else if (axis == 3) {
      // Batch dimension - the Z coordinate holds folded batch*channel
      // information. Here num_channels is the size of the C dimension
      // (N(C)HW), so the quotient is the batch index.
      channel_idx = pos.z / num_channels;
    }

    float scale_val = t_scale[channel_idx];
    int zero_point_val = t_zero_point[channel_idx];

    [[unroll]] for (int i = 0; i < 4; ++i) {
      IN_T qvalue = IN_T(intex[i]);
      OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
      $if OUT_DTYPE == "double":
        outtex[i] = float(value);
      $else:
        outtex[i] = value;
    }
  }

  write_texel(t_out, pos, outtex);
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ dequantize_texture:
MODE: per_tensor
- NAME: dequantize_per_token_texture3d
MODE: per_token
- NAME: dequantize_per_channel_texture3d
MODE: per_channel
Loading
Loading