Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
int axis;
int num_channels;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
Expand Down Expand Up @@ -137,7 +147,7 @@ void quantize_per_tensor() {
t_out[out_bufi] = qvalue;
}

#else
#elif defined(per_token)

void quantize_per_token() {
const int out_bufi = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -172,6 +182,45 @@ void quantize_per_token() {
t_out[out_bufi] = qvalue;
}

#else // per_channel

void quantize_per_channel() {
const int out_bufi = int(gl_GlobalInvocationID.x);

if (out_bufi >= out_numel) {
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

IN_T value = t_in[in_bufi];

// Calculate channel index based on the quantization axis (already converted to WHCN)
// The axis parameter is now in WHCN coordinate system:
// axis 0 -> W dimension (tidx.x)
// axis 1 -> H dimension (tidx.y)
// axis 2 -> C dimension (tidx.z)
// axis 3 -> N dimension (tidx.w)
int channel_idx = 0;

if (axis == 0) {
channel_idx = out_tidx.x;
} else if (axis == 1) {
channel_idx = out_tidx.y;
} else if (axis == 2) {
channel_idx = out_tidx.z;
} else if (axis == 3) {
channel_idx = out_tidx.w;
}

channel_idx = min(channel_idx, num_channels - 1);

OUT_T qvalue = quantize_val(value, t_scale[channel_idx], t_zero_point[channel_idx]);

t_out[out_bufi] = qvalue;
}

#endif

void main() {
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ quantize_buffer:
MODE: per_tensor
- NAME: quantize_per_token_buffer
MODE: per_token
- NAME: quantize_per_channel_buffer
MODE: per_channel
96 changes: 94 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ ${define_required_extensions(OUT_DTYPE)}

layout(std430) buffer;

#include "indexing_utils.h"

${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

Expand All @@ -45,11 +47,23 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
int axis;
int num_channels;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "ivec3", "t_in_limits")}
${layout_declare_ubo(B, "ivec3", "t_out_limits")}

#include "indexing_utils.h"
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}

#include "quantize.glslh"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
Expand Down Expand Up @@ -138,7 +152,7 @@ void quantize_per_tensor() {
write_texel(t_out, pos, outtex);
}

#else
#elif defined(per_token)

void quantize_per_token() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
Expand Down Expand Up @@ -177,6 +191,84 @@ void quantize_per_token() {
write_texel(t_out, pos, outtex);
}

#else // per_channel

void quantize_per_channel() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, t_in_limits))) {
return;
}

FVEC4_T intex = load_texel(t_in, pos);
IVEC4_T outtex;

// Calculate channel index based on the quantization axis (already converted to WHCN)
// The axis parameter is now in WHCN coordinate system:
// axis 0 -> W dimension (pos.x for texture, but width-packed so pos.x * 4 + component)
// axis 1 -> H dimension (pos.y)
// axis 2 -> C dimension (pos.z / C), but for 4D tensors this includes batch-channel folding
// axis 3 -> N dimension (pos.z / N), but for 4D tensors this includes batch-channel folding

if (axis == 0) {
// Width dimension - each texel component has different channel index
[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T value = IN_T(intex[i]);
int channel_idx = pos.x * 4 + i;
channel_idx = min(channel_idx, num_channels - 1);

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
outtex[i] = qvalue;
}
} else if (axis == 1) {
// Height dimension - all texel components use same channel index
int channel_idx = pos.y;
channel_idx = min(channel_idx, num_channels - 1);
float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T value = IN_T(intex[i]);
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
outtex[i] = qvalue;
}
} else if (axis == 2) {
// Channel dimension - for 4D tensors, need to account for batch-channel folding
// The Z coordinate contains folded batch*channel information
// We need to extract the actual channel index from the folded dimension
int folded_idx = pos.z;
int channel_idx = folded_idx % num_channels;

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T value = IN_T(intex[i]);
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
outtex[i] = qvalue;
}
} else if (axis == 3) {
// Batch dimension - for 4D tensors, need to account for batch-channel folding
// The Z coordinate contains folded batch*channel information
// We need to extract the actual batch index from the folded dimension
int folded_idx = pos.z;
int batch_idx = folded_idx / num_channels;

float scale_val = t_scale[batch_idx];
int zero_point_val = t_zero_point[batch_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T value = IN_T(intex[i]);
OUT_T qvalue = quantize_val(value, scale_val, zero_point_val);
outtex[i] = qvalue;
}
}

write_texel(t_out, pos, outtex);
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ quantize_texture:
MODE: per_tensor
- NAME: quantize_per_token_texture3d
MODE: per_token
- NAME: quantize_per_channel_texture3d
MODE: per_channel
Loading
Loading