Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,6 @@ def update_features_impl(op: OpKey):
@update_features(
[
operator.getitem,
# Quantization related ops will be fused via graph passes
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
# Symbolic integer ops
torch.ops.aten.sym_size.int,
operator.add,
Expand All @@ -250,6 +243,35 @@ def register_ephemeral_op(features: OpFeatures):
return features


@update_features(
[
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_token.default,
exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
]
)
def register_quantization_op(features: OpFeatures):
# Quantization requires buffer storage and width packing for scales/zero_points
# but we need to provide texture impl features for the partitioner to work properly
features.texture_impl = TextureImplFeatures(
uses_axis_map=True,
valid_packed_dims={
PackedDim.WIDTH,
},
)
features.buffer_impl = True
features.resize_fn = True
features.optimal_storage = VkStorageType.BUFFER
return features


@update_features(
[
exir_ops.edge.aten.add.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,14 @@ vkapi::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
return vkapi::kChar;
case vkgraph::VkDataType::INT32:
return vkapi::kInt;
case vkgraph::VkDataType::INT64:
return vkapi::kLong;
case vkgraph::VkDataType::FLOAT16:
return vkapi::kHalf;
case vkgraph::VkDataType::FLOAT32:
return vkapi::kFloat;
case vkgraph::VkDataType::FLOAT64:
return vkapi::kDouble;
}
}

Expand Down
16 changes: 7 additions & 9 deletions backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@
#ifndef CHOOSE_QPARAMS_GLSLH
#define CHOOSE_QPARAMS_GLSLH

// equivalent of the eps defined in the cpu implementation
#define SMALL_SCALE_THRESHOLD 6.1e-5

// Calculate scale and zero point from min and max values
void calculate_scale_and_zero_point(
float min_val,
float max_val,
int qmin,
int qmax,
float eps_threshold,
out float scale_val,
out int zero_point_val) {
// ensure we have zero included in our range
Expand All @@ -31,18 +29,18 @@ void calculate_scale_and_zero_point(
scale_val = 0.1;
}

// Cut off small scale
if (scale_val < SMALL_SCALE_THRESHOLD) {
// Cut off small scale using the provided eps threshold
if (scale_val < eps_threshold) {
float org_scale = scale_val;
scale_val = SMALL_SCALE_THRESHOLD;
scale_val = eps_threshold;

// Adjust min and max based on new scale
if (min_val == 0.0) {
max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
max_val = eps_threshold * float(qmax - qmin);
} else if (max_val == 0.0) {
min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
min_val = -eps_threshold * float(qmax - qmin);
} else {
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
float amplifier = eps_threshold / org_scale;
min_val *= amplifier;
max_val *= amplifier;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ $if MODE == "per_tensor":
layout(push_constant) uniform restrict Block {
int quant_min;
int quant_max;
float eps;
};
$else:
layout(push_constant) uniform restrict Block {
Expand Down Expand Up @@ -175,7 +176,7 @@ void choose_qparams_per_tensor() {

float scale_val;
int zero_point_val;
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val);

t_scale[0] = scale_val;
t_zero_point[0] = zero_point_val;
Expand Down Expand Up @@ -260,7 +261,7 @@ void choose_qparams_per_token() {

float scale_val;
int zero_point_val;
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val);

t_scale[token_id] = scale_val;
t_zero_point[token_id] = zero_point_val;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ $if MODE == "per_tensor":
layout(push_constant) uniform restrict Block {
int quant_min;
int quant_max;
float eps;
};
$else:
layout(push_constant) uniform restrict Block {
Expand Down Expand Up @@ -234,7 +235,7 @@ void choose_qparams_per_tensor() {

float scale_val;
int zero_point_val;
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val);

write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
Expand Down Expand Up @@ -372,7 +373,7 @@ void choose_qparams_per_token() {

float scale_val;
int zero_point_val;
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val);

// Convert token_id to 3D coordinates for output texture
// Assuming output tensors have the same layout as input but with different dimensions
Expand Down
51 changes: 50 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
int axis;
int num_channels;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
Expand Down Expand Up @@ -141,7 +151,7 @@ void dequantize_per_tensor() {
t_out[out_bufi] = value;
}

#else
#elif defined(per_token)

void dequantize_per_token() {
const int out_bufi = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -176,6 +186,45 @@ void dequantize_per_token() {
t_out[out_bufi] = value;
}

#else // per_channel

void dequantize_per_channel() {
const int out_bufi = int(gl_GlobalInvocationID.x);

if (out_bufi >= out_numel) {
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

IN_T qvalue = t_in[in_bufi];

// Calculate channel index based on the dequantization axis (already converted to WHCN)
// The axis parameter is now in WHCN coordinate system:
// axis 0 -> W dimension (tidx.x)
// axis 1 -> H dimension (tidx.y)
// axis 2 -> C dimension (tidx.z)
// axis 3 -> N dimension (tidx.w)
int channel_idx = 0;

if (axis == 0) {
channel_idx = out_tidx.x;
} else if (axis == 1) {
channel_idx = out_tidx.y;
} else if (axis == 2) {
channel_idx = out_tidx.z;
} else if (axis == 3) {
channel_idx = out_tidx.w;
}

channel_idx = min(channel_idx, num_channels - 1);

OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]);

t_out[out_bufi] = value;
}

#endif

void main() {
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ dequantize_buffer:
MODE: per_tensor
- NAME: dequantize_per_token_buffer
MODE: per_token
- NAME: dequantize_per_channel_buffer
MODE: per_channel
103 changes: 102 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
int axis;
int num_channels;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "ivec3", "t_in_limits")}
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
Expand Down Expand Up @@ -147,7 +157,7 @@ void dequantize_per_tensor() {
write_texel(t_out, pos, outtex);
}

#else
#elif defined(per_token)

void dequantize_per_token() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
Expand Down Expand Up @@ -189,6 +199,97 @@ void dequantize_per_token() {
write_texel(t_out, pos, outtex);
}

#else // per_channel

void dequantize_per_channel() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, t_in_limits))) {
return;
}

IVEC4_T intex = load_texel(t_in, pos);
FVEC4_T outtex;

// Calculate channel index based on the dequantization axis (already converted to WHCN)
// The axis parameter is now in WHCN coordinate system:
// axis 0 -> W dimension (pos.x)
// axis 1 -> H dimension (pos.y)
// axis 2 -> C dimension (pos.z)
// axis 3 -> N dimension (batch folding in texture storage)

if (axis == 0) {
// Width dimension - each texel component has different channel index
[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
int channel_idx = pos.x * 4 + i;
channel_idx = min(channel_idx, num_channels - 1);

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
} else if (axis == 1) {
int channel_idx = pos.y;
channel_idx = min(channel_idx, num_channels - 1);
float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
} else if (axis == 2) {
// Channel dimension - for 4D tensors, need to account for batch-channel folding
// The Z coordinate contains folded batch*channel information
// We need to extract the actual channel index from the folded dimension
int folded_idx = pos.z;
int channel_idx = folded_idx % num_channels;

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
} else if (axis == 3) {
// Batch dimension - for 4D tensors, need to account for batch-channel folding
// The Z coordinate contains folded batch*channel information
// We need to extract the actual channel index from the folded dimension
int folded_idx = pos.z;
// In this case num_channels actually corresponds to the number of channels
// the C dimension N(C)HW
int channel_idx = folded_idx / num_channels;

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
}

write_texel(t_out, pos, outtex);
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ dequantize_texture:
MODE: per_tensor
- NAME: dequantize_per_token_texture3d
MODE: per_token
- NAME: dequantize_per_channel_texture3d
MODE: per_channel
Loading
Loading