Skip to content

Commit c118d45

Browse files
author
morelos
committed
[ET-VK][Ops] quantize_per_tensor.tensor variant
Pull Request resolved: #12208 # Context We need a tensor variant for dequantize/quantize operators since that is the expected output of choose_qparams. # Changes This extends the logic that currently exists to support a tensor variant for scales and zeros. ghstack-source-id: 294625278 @exported-using-ghexport Differential Revision: [D77746136](https://our.internmc.facebook.com/intern/diff/D77746136/)
1 parent cf5f1a7 commit c118d45

File tree

6 files changed

+376
-13
lines changed

6 files changed

+376
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2727
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2828

2929
$if MODE == "per_tensor":
30+
$if SHAPE == "tensor":
31+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
32+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
33+
3034
layout(push_constant) uniform restrict Block {
31-
float scale;
32-
int zero_point;
35+
$if SHAPE == "scalar":
36+
float scale;
37+
int zero_point;
3338
int quant_min;
3439
int quant_max;
3540
};
@@ -142,7 +147,10 @@ void quantize_per_tensor() {
142147
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
143148

144149
IN_T value = t_in[in_bufi];
145-
OUT_T qvalue = quantize_val(value, scale, zero_point);
150+
$if SHAPE == "scalar":
151+
OUT_T qvalue = quantize_val(value, scale, zero_point);
152+
$if SHAPE == "tensor":
153+
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
146154

147155
t_out[out_bufi] = qvalue;
148156
}

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ quantize_buffer:
33
IN_DTYPE: float
44
OUT_DTYPE: int32
55
MODE: per_tensor
6+
SHAPE: tensor
67
generate_variant_forall:
78
IN_DTYPE:
89
- VALUE: half
@@ -15,6 +16,9 @@ quantize_buffer:
1516
shader_variants:
1617
- NAME: quantize_per_tensor_buffer
1718
MODE: per_tensor
19+
SHAPE: scalar
20+
- NAME: quantize_per_tensor_tensor_buffer
21+
MODE: per_tensor
1822
- NAME: quantize_per_token_buffer
1923
MODE: per_token
2024
- NAME: quantize_per_channel_buffer

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#define IVEC4_T ${texel_load_type(OUT_DTYPE, "texture3d")}
1818

1919
#define ${MODE}
20+
#define ${SHAPE}
2021

2122
${define_active_storage_type("texture3d")}
2223
${define_required_extensions(IN_DTYPE)}
@@ -32,9 +33,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3233
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3334

3435
$if MODE == "per_tensor":
36+
$if SHAPE == "tensor":
37+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
38+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
39+
3540
layout(push_constant) uniform restrict Block {
36-
float scale;
37-
int zero_point;
41+
$if SHAPE == "scalar":
42+
float scale;
43+
int zero_point;
3844
int quant_min;
3945
int quant_max;
4046
};
@@ -146,7 +152,10 @@ void quantize_per_tensor() {
146152

147153
[[unroll]] for (int i = 0; i < 4; ++i) {
148154
IN_T value = IN_T(intex[i]);
149-
OUT_T qvalue = quantize_val(value, scale, zero_point);
155+
$if SHAPE == "scalar":
156+
OUT_T qvalue = quantize_val(value, scale, zero_point);
157+
$if SHAPE == "tensor":
158+
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
150159
outtex[i] = qvalue;
151160
}
152161
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ quantize_texture:
33
IN_DTYPE: float
44
OUT_DTYPE: int32
55
MODE: per_tensor
6+
SHAPE: tensor
67
generate_variant_forall:
78
IN_DTYPE:
89
- VALUE: half
@@ -15,6 +16,9 @@ quantize_texture:
1516
shader_variants:
1617
- NAME: quantize_per_tensor_texture3d
1718
MODE: per_tensor
19+
SHAPE: scalar
20+
- NAME: quantize_per_tensor_tensor_texture3d
21+
MODE: per_tensor
1822
- NAME: quantize_per_token_texture3d
1923
MODE: per_token
2024
- NAME: quantize_per_channel_texture3d

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,23 @@ void add_quantize_per_tensor_node(
8080
const ValueRef& quant_min,
8181
const ValueRef& quant_max,
8282
const ValueRef& output) {
83+
const bool is_tensor_scale_zp =
84+
graph.val_is_tensor(scale) && graph.val_is_tensor(zero_point);
85+
8386
std::string kernel_name("quantize_per_tensor");
87+
if (is_tensor_scale_zp) {
88+
kernel_name += "_tensor";
89+
}
8490
add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
8591
add_dtype_suffix(kernel_name, graph.dtype_of(input));
8692
add_dtype_suffix(kernel_name, graph.dtype_of(output));
8793

88-
float scale_val = static_cast<float>(graph.get_double(scale));
89-
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
94+
float scale_val = 1.0;
95+
int zero_point_val = 0;
96+
if (!is_tensor_scale_zp) {
97+
scale_val = static_cast<float>(graph.get_double(scale));
98+
zero_point_val = static_cast<int>(graph.get_int(zero_point));
99+
}
90100
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
91101
int quant_max_val = static_cast<int>(graph.get_int(quant_max));
92102

@@ -100,15 +110,17 @@ void add_quantize_per_tensor_node(
100110
graph.strides_ubo(input),
101111
graph.sizes_ubo(output),
102112
graph.strides_ubo(output)};
113+
} else {
114+
param_ubos = {
115+
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
116+
}
117+
118+
if (is_tensor_scale_zp) {
103119
push_constants = {
104-
PushConstantDataInfo(&scale_val, sizeof(float)),
105-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
106120
PushConstantDataInfo(&quant_min_val, sizeof(int)),
107121
PushConstantDataInfo(&quant_max_val, sizeof(int)),
108122
};
109123
} else {
110-
param_ubos = {
111-
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
112124
push_constants = {
113125
PushConstantDataInfo(&scale_val, sizeof(float)),
114126
PushConstantDataInfo(&zero_point_val, sizeof(int)),
@@ -122,13 +134,20 @@ void add_quantize_per_tensor_node(
122134
graph.hashed_layout_of(input),
123135
};
124136

137+
std::vector<ArgGroup> inputs_and_outputs = {
138+
{output, vkapi::kWrite}, {input, vkapi::kRead}};
139+
if (is_tensor_scale_zp) {
140+
inputs_and_outputs.emplace_back(
141+
ArgGroup{{scale, zero_point}, vkapi::kRead});
142+
}
143+
125144
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
126145
graph,
127146
VK_KERNEL_FROM_STR(kernel_name),
128147
default_pick_global_wg_size,
129148
default_pick_local_wg_size,
130149
// Inputs and Outputs
131-
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
150+
inputs_and_outputs,
132151
// Shader param buffers
133152
param_ubos,
134153
// Push Constants
@@ -489,6 +508,9 @@ REGISTER_OPERATORS {
489508
VK_REGISTER_OP(
490509
quantized_decomposed.quantize_per_tensor.default,
491510
quantize_per_tensor_impl);
511+
VK_REGISTER_OP(
512+
quantized_decomposed.quantize_per_tensor.tensor,
513+
quantize_per_tensor_impl);
492514
VK_REGISTER_OP(
493515
quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
494516
VK_REGISTER_OP(

0 commit comments

Comments
 (0)