
Commit 6ecf0d1

Authored and committed by morelos
[ET-VK][Ops] quantize_per_tensor.tensor variant

# Context

We need a tensor variant for the dequantize/quantize operators, since that is the expected output of choose_qparams.

# Changes

This extends the existing logic to support a tensor variant for the scales and zero points.

Differential Revision: [D77746136](https://our.internmc.facebook.com/intern/diff/D77746136/)

[ghstack-poisoned]
1 parent 014e327 commit 6ecf0d1
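
For context, a minimal PyTorch-level sketch of the pattern this commit targets. This is illustrative only and not part of the diff; the op schemas (argument order, the eps argument, and the module that registers the `quantized_decomposed` library) are assumed from upstream PyTorch and may differ slightly.

```python
# Illustrative sketch: choose_qparams returns scale/zero_point as tensors,
# so the quantize op that consumes them must be the .tensor overload rather
# than .default (which takes a Python float/int).
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401  # assumed to register quantized_decomposed ops

x = torch.randn(2, 3, 8, 8)

# Tensor-valued quantization parameters produced by choose_qparams.
scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
    x, -128, 127, torch.finfo(torch.float32).eps, torch.int8
)

# The overload handled by this commit on the Vulkan side: scale and
# zero_point arrive as tensors instead of scalars.
xq = torch.ops.quantized_decomposed.quantize_per_tensor.tensor(
    x, scale, zero_point, -128, 127, torch.int8
)
```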

File tree: 6 files changed (+413 lines, -41 lines)

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl

Lines changed: 11 additions & 3 deletions

@@ -27,9 +27,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}

 $if MODE == "per_tensor":
+  $if SHAPE == "tensor":
+    ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+    ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
   layout(push_constant) uniform restrict Block {
-    float scale;
-    int zero_point;
+    $if SHAPE == "scalar":
+      float scale;
+      int zero_point;
     int quant_min;
     int quant_max;
   };

@@ -142,7 +147,10 @@ void quantize_per_tensor() {
   const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

   IN_T value = t_in[in_bufi];
-  OUT_T qvalue = quantize_val(value, scale, zero_point);
+  $if SHAPE == "scalar":
+    OUT_T qvalue = quantize_val(value, scale, zero_point);
+  $if SHAPE == "tensor":
+    OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);

   t_out[out_bufi] = qvalue;
 }

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 4 additions & 0 deletions

@@ -3,6 +3,7 @@ quantize_buffer:
     IN_DTYPE: float
     OUT_DTYPE: int32
     MODE: per_tensor
+    SHAPE: tensor
   generate_variant_forall:
     IN_DTYPE:
       - VALUE: half

@@ -15,6 +16,9 @@ quantize_buffer:
   shader_variants:
     - NAME: quantize_per_tensor_buffer
       MODE: per_tensor
+      SHAPE: scalar
+    - NAME: quantize_per_tensor_tensor_buffer
+      MODE: per_tensor
     - NAME: quantize_per_token_buffer
       MODE: per_token
     - NAME: quantize_per_channel_buffer

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl

Lines changed: 12 additions & 3 deletions

@@ -17,6 +17,7 @@
 #define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")}

 #define ${MODE}
+#define ${SHAPE}

 ${define_active_storage_type("texture3d")}
 ${define_required_extensions(IN_DTYPE)}

@@ -32,9 +33,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

 $if MODE == "per_tensor":
+  $if SHAPE == "tensor":
+    ${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
+    ${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
+
   layout(push_constant) uniform restrict Block {
-    float scale;
-    int zero_point;
+    $if SHAPE == "scalar":
+      float scale;
+      int zero_point;
     int quant_min;
     int quant_max;
   };

@@ -146,7 +152,10 @@ void quantize_per_tensor() {

   [[unroll]] for (int i = 0; i < 4; ++i) {
     IN_T value = IN_T(intex[i]);
-    OUT_T qvalue = quantize_val(value, scale, zero_point);
+    $if SHAPE == "scalar":
+      OUT_T qvalue = quantize_val(value, scale, zero_point);
+    $if SHAPE == "tensor":
+      OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
     outtex[i] = qvalue;
   }
   write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 4 additions & 0 deletions

@@ -3,6 +3,7 @@ quantize_texture:
     IN_DTYPE: float
     OUT_DTYPE: int32
     MODE: per_tensor
+    SHAPE: tensor
   generate_variant_forall:
     IN_DTYPE:
       - VALUE: half

@@ -15,6 +16,9 @@ quantize_texture:
  shader_variants:
    - NAME: quantize_per_tensor_texture3d
      MODE: per_tensor
+      SHAPE: scalar
+    - NAME: quantize_per_tensor_tensor_texture3d
+      MODE: per_tensor
    - NAME: quantize_per_token_texture3d
      MODE: per_token
    - NAME: quantize_per_channel_texture3d

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 43 additions & 19 deletions

@@ -51,17 +51,19 @@ utils::uvec3 quantize_per_channel_local_wg_size(

   const ValueRef input = args.at(1).refs.at(0);

-  utils::uvec3 local_wg_size = graph->create_local_wg_size(global_workgroup_size);
-
-  // WORKAROUND: The CommandBuffer::dispatch function divides global_workgroup_size
-  // by local_workgroup_size to get the number of workgroups to dispatch.
-  // For per-channel quantization along the batch axis, we need to ensure that
-  // we dispatch the correct number of workgroups in the Z dimension to cover
-  // all batch-channel combinations.
+  utils::uvec3 local_wg_size =
+      graph->create_local_wg_size(global_workgroup_size);
+
+  // WORKAROUND: The CommandBuffer::dispatch function divides
+  // global_workgroup_size by local_workgroup_size to get the number of
+  // workgroups to dispatch. For per-channel quantization along the batch axis,
+  // we need to ensure that we dispatch the correct number of workgroups in the
+  // Z dimension to cover all batch-channel combinations.
   //
-  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2], local_wg_size[2])
-  // might reduce the number of workgroups dispatched. To ensure we dispatch
-  // global_workgroup_size[2] workgroups in the Z dimension, we set local_wg_size[2] = 1.
+  // If local_wg_size[2] > 1, then div_up(global_workgroup_size[2],
+  // local_wg_size[2]) might reduce the number of workgroups dispatched. To
+  // ensure we dispatch global_workgroup_size[2] workgroups in the Z dimension,
+  // we set local_wg_size[2] = 1.
   const auto input_sizes = graph->sizes_of(input);
   if (global_workgroup_size[2] > 1 && input_sizes[3] > 0) {
     local_wg_size[2] = 1;

@@ -78,13 +80,23 @@ void add_quantize_per_tensor_node(
     const ValueRef& quant_min,
     const ValueRef& quant_max,
     const ValueRef& output) {
+  const bool is_tensor_scale_zp =
+      graph.val_is_tensor(scale) && graph.val_is_tensor(zero_point);
+
   std::string kernel_name("quantize_per_tensor");
+  if (is_tensor_scale_zp) {
+    kernel_name += "_tensor";
+  }
   add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
   add_dtype_suffix(kernel_name, graph.dtype_of(input));
   add_dtype_suffix(kernel_name, graph.dtype_of(output));

-  float scale_val = static_cast<float>(graph.get_double(scale));
-  int zero_point_val = static_cast<int>(graph.get_int(zero_point));
+  float scale_val = 1.0;
+  int zero_point_val = 0;
+  if (!is_tensor_scale_zp) {
+    scale_val = static_cast<float>(graph.get_double(scale));
+    zero_point_val = static_cast<int>(graph.get_int(zero_point));
+  }
   int quant_min_val = static_cast<int>(graph.get_int(quant_min));
   int quant_max_val = static_cast<int>(graph.get_int(quant_max));

@@ -98,15 +110,17 @@ void add_quantize_per_tensor_node(
         graph.strides_ubo(input),
         graph.sizes_ubo(output),
         graph.strides_ubo(output)};
+  } else {
+    param_ubos = {
+        graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
+  }
+
+  if (is_tensor_scale_zp) {
     push_constants = {
-        PushConstantDataInfo(&scale_val, sizeof(float)),
-        PushConstantDataInfo(&zero_point_val, sizeof(int)),
         PushConstantDataInfo(&quant_min_val, sizeof(int)),
         PushConstantDataInfo(&quant_max_val, sizeof(int)),
     };
   } else {
-    param_ubos = {
-        graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
     push_constants = {
         PushConstantDataInfo(&scale_val, sizeof(float)),
         PushConstantDataInfo(&zero_point_val, sizeof(int)),

@@ -120,13 +134,20 @@ void add_quantize_per_tensor_node(
       graph.hashed_layout_of(input),
   };

+  std::vector<ArgGroup> inputs_and_outputs = {
+      {output, vkapi::kWrite}, {input, vkapi::kRead}};
+  if (is_tensor_scale_zp) {
+    inputs_and_outputs.emplace_back(
+        ArgGroup{{scale, zero_point}, vkapi::kRead});
+  }
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       default_pick_global_wg_size,
       default_pick_local_wg_size,
       // Inputs and Outputs
-      {{output, vkapi::kWrite}, {input, vkapi::kRead}},
+      inputs_and_outputs,
       // Shader param buffers
       param_ubos,
       // Push Constants

@@ -241,8 +262,8 @@ void add_quantize_per_channel_node(

   int num_channels;
   if (axis_val == 0 && ndim == 4 && !graph.is_buffer_storage(input)) {
-    // For batch dimension quantization in 4D tensors, pass the actual number of channels
-    // so the shader can correctly unfold the batch-channel folding
+    // For batch dimension quantization in 4D tensors, pass the actual number of
+    // channels so the shader can correctly unfold the batch-channel folding
     num_channels = static_cast<int>(input_sizes[1]); // Channel dimension
   } else {
     num_channels = static_cast<int>(input_sizes[axis_val]);

@@ -487,6 +508,9 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.quantize_per_tensor.default,
       quantize_per_tensor_impl);
+  VK_REGISTER_OP(
+      quantized_decomposed.quantize_per_tensor.tensor,
+      quantize_per_tensor_impl);
   VK_REGISTER_OP(
       quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
   VK_REGISTER_OP(
