Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def register_affine_quantization_op():

@update_features(
[
exir_ops.edge.torchao.choose_qparams_affine.default,
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
]
Expand All @@ -184,6 +183,20 @@ def register_torchao_quantization_op():
)


@update_features(
exir_ops.edge.torchao.choose_qparams_affine.default,
)
def register_torchao_choose_qparams_affine():
return OpFeatures(
inputs_storage=utils.CONTIGUOUS_ANY,
outputs_storage=[
utils.CONTIGUOUS_BUFFER, # scales
utils.CONTIGUOUS_BUFFER, # zero_points
],
supports_resize=True,
)


@update_features(
[
exir_ops.edge.aten.add.Tensor,
Expand Down
51 changes: 48 additions & 3 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ bool ComputeGraph::is_contiguous_buffer_tensor(const ValueRef idx) const {
return is_contiguous(idx);
}

bool ComputeGraph::is_contiguous_texture_tensor(const ValueRef idx) const {
if (!val_is_tensor(idx)) {
return false;
}
if (is_buffer_storage(idx)) {
return false;
}
return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
}

bool ComputeGraph::is_standard_channels_packed_texture_tensor(
const ValueRef idx) const {
if (!val_is_tensor(idx)) {
Expand All @@ -343,15 +353,50 @@ bool ComputeGraph::is_standard_channels_packed_texture_tensor(
return has_standard_axis_map(idx) && packed_dim_of(idx) == 2;
}

bool ComputeGraph::is_standard_width_packed_texture_tensor(
bool ComputeGraph::is_2d_matrix(const ValueRef idx) const {
std::vector<int64_t> sizes = sizes_of(idx);
const size_t ndim = sizes.size();
if (sizes.size() < 2) {
return false;
}
if (sizes.size() == 2) {
return true;
}

// Check that outermost dims have size of 1
for (int d = 0; d < ndim - 2; d++) {
if (sizes[d] != 1) {
return false;
}
}

return true;
}

bool ComputeGraph::is_vectorizable_contiguous_2d_matrix(
const ValueRef idx) const {
if (!val_is_tensor(idx)) {
if (!is_2d_matrix(idx)) {
return false;
}
if (is_buffer_storage(idx)) {
return is_contiguous_buffer_tensor(idx) &&
size_at<int32_t>(-1, idx) % 4 == 0;
}
return is_contiguous_texture_tensor(idx);
}

bool ComputeGraph::is_vectorizable_width_packed_tensor(
const ValueRef idx) const {
// Not a tensor - return false
if (!val_is_tensor(idx)) {
return false;
}
return has_standard_axis_map(idx) && packed_dim_of(idx) == 0;
if (is_buffer_storage(idx)) {
return is_contiguous_buffer_tensor(idx) &&
size_at<int32_t>(-1, idx) % 4 == 0;
}

return is_standard_channels_packed_texture_tensor(idx);
}

ValueRef ComputeGraph::add_tensor(
Expand Down
30 changes: 26 additions & 4 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,18 +382,40 @@ class ComputeGraph final {
* 1. The value at `idx` is a tensor
* 2. The tensor at `idx` has texture storage
* 3. The texture backed tensor at `idx` has a standard axis mapping
* 4. The texture backed tensor at `idx` is channels packed
* 4. The texture backed tensor at `idx` is width packed
*/
bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;
bool is_contiguous_texture_tensor(const ValueRef idx) const;

/*
* Checks that the following is true:
* 1. The value at `idx` is a tensor
* 2. The tensor at `idx` has texture storage
* 3. The texture backed tensor at `idx` has a standard axis mapping
* 4. The texture backed tensor at `idx` is width packed
* 4. The texture backed tensor at `idx` is channels packed
*/
bool is_standard_channels_packed_texture_tensor(const ValueRef idx) const;

/*
* Checks that the value at `idx` is either a 2D tensor, or if the tensor has
* more than 2 dims, the outermost dims have size of 1, i.e. can be squeezed
* to be a 2D tensor.
*/
bool is_2d_matrix(const ValueRef idx) const;

/*
* Same as the above, but also requires that the tensor is a contiguous
* buffer with a width divisible by 4 or a standard width packed texture.
*/
bool is_vectorizable_contiguous_2d_matrix(const ValueRef idx) const;

/*
* Checks that the following is true:
* 1. The value at `idx` is a tensor
* 2. The tensor at `idx` is width packed
* 3. The tensor at `idx` has a standard axis mapping or is a contiguous
* buffer
*/
bool is_standard_width_packed_texture_tensor(const ValueRef idx) const;
bool is_vectorizable_width_packed_tensor(const ValueRef idx) const;

inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
const {
Expand Down
184 changes: 184 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${texel_load_component_type(DTYPE, STORAGE)}

#define NUM_OUTPUTS_PER_WG ${NUM_OUTPUTS_PER_WG}
#define NUM_WORKERS_PER_OUTPUT ${NUM_WORKERS_PER_OUTPUT}

// Maximum total threads in a work group
#define MAX_THREADS 256

${define_active_storage_type(STORAGE)}
${define_required_extensions("int8")}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

#include "common.glslh"

${layout_declare_tensor(B, "w", "t_scales", "float", "buffer")}
${layout_declare_tensor(B, "w", "t_zps", "int", "buffer")}
${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)}

${layout_declare_ubo(B, "ivec4", "input_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(push_constant) uniform PushConstants {
int quant_min;
int quant_max;
};

// Shared memory for cooperative min/max finding
shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];
shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT];

const float SMALL_SCALE_THRESHOLD = 6.1e-5;

void calculate_scale_and_zero_point(
float min_val,
float max_val,
int qmin,
int qmax,
out float scale,
out int8_t zero_point) {

// Extend the [min, max] interval to ensure it contains 0
min_val = min(min_val, 0.0);
max_val = max(max_val, 0.0);

// Calculate scale
scale = (max_val - min_val) / float(qmax - qmin);

// Handle special cases for scale
if (scale == 0.0 || isinf(1.0 / scale)) {
scale = 0.1;
}

// Cut off small scale
if (scale < SMALL_SCALE_THRESHOLD) {
float org_scale = scale;
scale = SMALL_SCALE_THRESHOLD;
// Adjust the min and max based on the new scale
if (min_val == 0.0) {
max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
} else if (max_val == 0.0) {
min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
} else {
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
min_val *= amplifier;
max_val *= amplifier;
}
}

// Zero-point computation
float zero_point_from_min = float(qmin) - min_val / scale;
float zero_point_from_max = float(qmax) - max_val / scale;
float zero_point_from_min_error = abs(float(qmin)) - abs(min_val / scale);
float zero_point_from_max_error = abs(float(qmax)) - abs(max_val / scale);

float initial_zero_point = zero_point_from_min_error < zero_point_from_max_error
? zero_point_from_min
: zero_point_from_max;

// Nudge zero point to be an integer
int nudged_zero_point;
if (initial_zero_point < float(qmin)) {
nudged_zero_point = qmin;
} else if (initial_zero_point > float(qmax)) {
nudged_zero_point = qmax;
} else {
nudged_zero_point = int(round(initial_zero_point));
}

zero_point = int8_t(nudged_zero_point);
}

#ifdef USING_BUFFER

VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) {
return t_input[(y * ntexels_x) + x4];
}

#else // USING_TEXTURE

VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) {
return texelFetch(t_input, ivec3(x4, y, 0), 0);
}

#endif // USING_BUFFER

void main() {
const int worker_id = int(gl_LocalInvocationID.x);
const int output_id = int(gl_LocalInvocationID.y);

const int output_y = int(gl_GlobalInvocationID.y);

if (output_y >= input_sizes.y) {
return;
}

// Input is 2D tensor (height x width), width-packed
// Each channel corresponds to a row in the tensor
const int X4 = div_4(input_sizes.x);

// Initialize thread-local min/max
float local_min = 1e30;
float local_max = -1e30;

// Each thread processes elements along their assigned output_id with stride
// NUM_WORKERS_PER_OUTPUT
for (int x4 = worker_id; x4 < X4; x4 += NUM_WORKERS_PER_OUTPUT) {
VEC4_T in_texel = load_input_x4(x4, output_y, X4);
for (int i = 0; i < 4; i++) {
local_min = min(local_min, in_texel[i]);
local_max = max(local_max, in_texel[i]);
}
}

// Store thread-local results in shared memory
shared_min[output_id][worker_id] = local_min;
shared_max[output_id][worker_id] = local_max;

memoryBarrierShared();
barrier();

// Tree reduction to compute the overall result
for (int i = NUM_WORKERS_PER_OUTPUT / 2; i > 0; i >>= 1) {
if (worker_id < i) {
shared_min[output_id][worker_id] = min(
shared_min[output_id][worker_id],
shared_min[output_id][worker_id + i]);
shared_max[output_id][worker_id] = max(
shared_max[output_id][worker_id],
shared_max[output_id][worker_id + i]);
}
memoryBarrierShared();
barrier();
}

// Only first thread will write out result
if (worker_id == 0) {
local_min = shared_min[output_id][0];
local_max = shared_max[output_id][0];

float scale;
int8_t zero_point;
calculate_scale_and_zero_point(
local_min, local_max, quant_min, quant_max, scale, zero_point);

t_scales[output_y] = scale;
t_zps[output_y] = zero_point;
}
}
23 changes: 23 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

choose_qparams_per_row:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
NUM_OUTPUTS_PER_WG: 1
NUM_WORKERS_PER_OUTPUT: 64
generate_variant_forall:
STORAGE:
- VALUE: texture3d
- VALUE: buffer
DTYPE:
- VALUE: float
shader_variants:
- NAME: choose_qparams_per_row_o1w64
- NAME: choose_qparams_per_row_o4w16
NUM_OUTPUTS_PER_WG: 4
NUM_WORKERS_PER_OUTPUT: 16
Loading
Loading