Skip to content

Commit 10512c5

Browse files
author
morelos
committed
[ET-VK][Ops] dequantize_per_tensor.tensor variant
# Context We need a tensor variant for dequantize/quantize operators since that is the expected output of choose_qparams. # Changes This extends the logic that currently exists to support a tensor variant for scales and zeros. Differential Revision: [D77746135](https://our.internmc.facebook.com/intern/diff/D77746135/) [ghstack-poisoned]
1 parent 6ecf0d1 commit 10512c5

File tree

6 files changed

+416
-13
lines changed

6 files changed

+416
-13
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2727
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2828

2929
$if MODE == "per_tensor":
30+
$if SHAPE == "tensor":
31+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
32+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
33+
3034
layout(push_constant) uniform restrict Block {
31-
float scale;
32-
int zero_point;
35+
$if SHAPE == "scalar":
36+
float scale;
37+
int zero_point;
3338
int quant_min;
3439
int quant_max;
3540
};
@@ -146,7 +151,10 @@ void dequantize_per_tensor() {
146151
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
147152

148153
IN_T qvalue = t_in[in_bufi];
149-
OUT_T value = dequantize_val(qvalue, scale, zero_point);
154+
$if SHAPE == "scalar":
155+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
156+
$if SHAPE == "tensor":
157+
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
150158

151159
t_out[out_bufi] = value;
152160
}

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ dequantize_buffer:
33
IN_DTYPE: int32
44
OUT_DTYPE: float
55
MODE: per_tensor
6+
SHAPE: tensor
67
generate_variant_forall:
78
IN_DTYPE:
89
- VALUE: uint8
@@ -15,6 +16,9 @@ dequantize_buffer:
1516
shader_variants:
1617
- NAME: dequantize_per_tensor_buffer
1718
MODE: per_tensor
19+
SHAPE: scalar
20+
- NAME: dequantize_per_tensor_tensor_buffer
21+
MODE: per_tensor
1822
- NAME: dequantize_per_token_buffer
1923
MODE: per_token
2024
- NAME: dequantize_per_channel_buffer

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3030
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3131

3232
$if MODE == "per_tensor":
33+
$if SHAPE == "tensor":
34+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
35+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
36+
3337
layout(push_constant) uniform restrict Block {
34-
float scale;
35-
int zero_point;
38+
$if SHAPE == "scalar":
39+
float scale;
40+
int zero_point;
3641
int quant_min;
3742
int quant_max;
3843
};
@@ -148,7 +153,11 @@ void dequantize_per_tensor() {
148153

149154
[[unroll]] for (int i = 0; i < 4; ++i) {
150155
IN_T qvalue = IN_T(intex[i]);
151-
OUT_T value = dequantize_val(qvalue, scale, zero_point);
156+
$if SHAPE == "scalar":
157+
OUT_T value = dequantize_val(qvalue, scale, zero_point);
158+
$if SHAPE == "tensor":
159+
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
160+
152161
$if OUT_DTYPE == "double":
153162
outtex[i] = float(value);
154163
$else:

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ dequantize_texture:
33
IN_DTYPE: int32
44
OUT_DTYPE: float
55
MODE: per_tensor
6+
SHAPE: tensor
67
generate_variant_forall:
78
IN_DTYPE:
89
- VALUE: uint8
@@ -15,6 +16,9 @@ dequantize_texture:
1516
shader_variants:
1617
- NAME: dequantize_per_tensor_texture3d
1718
MODE: per_tensor
19+
SHAPE: scalar
20+
- NAME: dequantize_per_tensor_tensor_texture3d
21+
MODE: per_tensor
1822
- NAME: dequantize_per_token_texture3d
1923
MODE: per_token
2024
- NAME: dequantize_per_channel_texture3d

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,23 @@ void add_dequantize_per_tensor_node(
7878
const ValueRef& quant_min,
7979
const ValueRef& quant_max,
8080
const ValueRef& output) {
81+
const bool is_tensor_scale_zp =
82+
graph.val_is_tensor(scale) && graph.val_is_tensor(zero_point);
83+
8184
std::string kernel_name("dequantize_per_tensor");
85+
if (is_tensor_scale_zp) {
86+
kernel_name += "_tensor";
87+
}
8288
add_storage_type_suffix(kernel_name, graph.storage_type_of(input));
8389
add_dtype_suffix(kernel_name, graph.dtype_of(input));
8490
add_dtype_suffix(kernel_name, graph.dtype_of(output));
8591

86-
float scale_val = static_cast<float>(graph.get_double(scale));
87-
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
92+
float scale_val = 1.0;
93+
int zero_point_val = 0;
94+
if (!is_tensor_scale_zp) {
95+
scale_val = static_cast<float>(graph.get_double(scale));
96+
zero_point_val = static_cast<int>(graph.get_int(zero_point));
97+
}
8898
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
8999
int quant_max_val = static_cast<int>(graph.get_int(quant_max));
90100

@@ -98,15 +108,17 @@ void add_dequantize_per_tensor_node(
98108
graph.strides_ubo(input),
99109
graph.sizes_ubo(output),
100110
graph.strides_ubo(output)};
111+
} else {
112+
param_ubos = {
113+
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
114+
}
115+
116+
if (is_tensor_scale_zp) {
101117
push_constants = {
102-
PushConstantDataInfo(&scale_val, sizeof(float)),
103-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
104118
PushConstantDataInfo(&quant_min_val, sizeof(int)),
105119
PushConstantDataInfo(&quant_max_val, sizeof(int)),
106120
};
107121
} else {
108-
param_ubos = {
109-
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
110122
push_constants = {
111123
PushConstantDataInfo(&scale_val, sizeof(float)),
112124
PushConstantDataInfo(&zero_point_val, sizeof(int)),
@@ -120,13 +132,20 @@ void add_dequantize_per_tensor_node(
120132
graph.hashed_layout_of(input),
121133
};
122134

135+
std::vector<ArgGroup> inputs_and_outputs = {
136+
{output, vkapi::kWrite}, {input, vkapi::kRead}};
137+
if (is_tensor_scale_zp) {
138+
inputs_and_outputs.emplace_back(
139+
ArgGroup{{scale, zero_point}, vkapi::kRead});
140+
}
141+
123142
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
124143
graph,
125144
VK_KERNEL_FROM_STR(kernel_name),
126145
default_pick_global_wg_size,
127146
default_pick_local_wg_size,
128147
// Inputs and Outputs
129-
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
148+
inputs_and_outputs,
130149
// Shader param buffers
131150
param_ubos,
132151
// Push Constants
@@ -517,6 +536,9 @@ REGISTER_OPERATORS {
517536
VK_REGISTER_OP(
518537
quantized_decomposed.dequantize_per_tensor.default,
519538
dequantize_per_tensor_impl);
539+
VK_REGISTER_OP(
540+
quantized_decomposed.dequantize_per_tensor.tensor,
541+
dequantize_per_tensor_impl);
520542
VK_REGISTER_OP(
521543
quantized_decomposed.dequantize_per_token.default,
522544
dequantize_per_token_impl);

0 commit comments

Comments
 (0)