
Commit a08c8b7

larryliu0820 authored and facebook-github-bot committed
Add dequantize_per_channel.out kernel (#1375)
Summary:
Pull Request resolved: #1375

As titled

Reviewed By: jerryzh168

Differential Revision: D51957534

fbshipit-source-id: 9962ec50862be9bdb24872c740bd6b9aea80121a
1 parent cff7a97 commit a08c8b7

File tree

4 files changed: +194 -7 lines changed

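The new kernel implements per-channel dequantization: for each element at flat index i whose index along `axis` is c, it computes

    out[i] = (input[i] - zero_points[c]) * scales[c]

so every slice along `axis` carries its own scale and zero point, unlike dequantize_per_tensor, which applies a single scale/zero-point pair to the whole tensor.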

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 132 additions & 7 deletions
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <algorithm>
 #include <cinttypes>
@@ -29,8 +30,6 @@ namespace {
  */
 void check_dequantize_per_tensor_args(
     const Tensor& input,
-    double scale,
-    int64_t zero_point,
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
@@ -58,9 +57,6 @@ void check_dequantize_per_tensor_args(
       "quant min: %" PRId64 " is greater than quant max: %" PRId64,
       quant_min,
       quant_max);
-
-  (void)scale;
-  (void)zero_point;
 }
 
 } // namespace
@@ -87,8 +83,7 @@ Tensor& dequantize_per_tensor_out(
       err == torch::executor::Error::Ok,
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
-  check_dequantize_per_tensor_args(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
+  check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
 
   // calculate the dequantized output, cast scale to float to match fbgemm
   // behavior
@@ -162,6 +157,136 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }
 
+Tensor& dequantize_per_channel_out(
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  torch::executor::Error err = resize_tensor(out, input.sizes());
+
+  // normalize axis
+  ET_CHECK_MSG(
+      tensor_has_dim(input, axis),
+      "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd",
+      ssize_t(axis),
+      ssize_t(input.dim()));
+
+  if (axis < 0) {
+    axis += nonzero_dim(input);
+  }
+
+  ET_CHECK_MSG(
+      err == torch::executor::Error::Ok,
+      "Failed to resize out Tensor in dequantize_per_channel_out");
+
+  ET_CHECK_MSG(
+      scale.scalar_type() == ScalarType::Double,
+      "scale.scalar_type() %" PRId8 " is not double type",
+      static_cast<int8_t>(scale.scalar_type()));
+
+  ET_CHECK_MSG(
+      scale.numel() == input.size(axis),
+      "scale.numel() %zd != input.size(axis) %zd",
+      ssize_t(scale.numel()),
+      ssize_t(input.size(axis)));
+
+  ET_CHECK_MSG(
+      zero_point.scalar_type() == ScalarType::Long,
+      "zero_point.scalar_type() %" PRId8 " is not integer type",
+      static_cast<int8_t>(zero_point.scalar_type()));
+
+  ET_CHECK_MSG(
+      zero_point.numel() == input.size(axis),
+      "zero_point.numel() %zd != input.size(axis) %zd",
+      ssize_t(zero_point.numel()),
+      ssize_t(input.size(axis)));
+
+  check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
+
+  // a list contains all dimensions except axis
+  int64_t dims[input.dim() - 1];
+  for (int64_t i = 0; i < input.dim() - 1; i++) {
+    if (i < axis) {
+      dims[i] = i;
+    } else {
+      dims[i] = i - 1;
+    }
+  }
+  const double* scale_data = scale.const_data_ptr<double>();
+  const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
+
+  exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
+      exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
+
+  // Actual dequantization logic
+  // input, out are the input and output tensors
+  // channel_ix is the index along the axis dimension. 0 <= channel_ix <
+  // input.size(axis).
+  //   i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
+  //   will be 0, 1, 2, ... C-1
+  // in_ix is the flat index of the element you are dequantizing.
+  //   in other words you are dequantizing in_data[in_ix]
+#define DEQUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                        \
+  case ScalarType::out_dtype:                                                  \
+    for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
+      double _scale = scale_data[channel_ix];                                  \
+      int64_t _zero_point = zero_point_data[channel_ix];                       \
+      apply_over_dim_list(                                                     \
+          [input, out, _scale, _zero_point](size_t in_ix) {                    \
+            out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = static_cast<CTYPE_OUT>( \
+                (input.const_data_ptr<CTYPE_IN>()[in_ix] - _zero_point) *      \
+                _scale);                                                       \
+          },                                                                   \
+          input,                                                               \
+          optional_dim_list,                                                   \
+          channel_ix);                                                         \
+    }                                                                          \
+    break;
+#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype)             \
+  case ScalarType::in_dtype:                                 \
+    switch (out.scalar_type()) {                             \
+      ET_FORALL_FLOAT_TYPES_WITH(CTYPE_IN, DEQUANTIZE_IMPL); \
+      default:                                               \
+        ET_CHECK_MSG(                                        \
+            false,                                           \
+            "Unhandled output dtype %" PRId8,                \
+            static_cast<int8_t>(out.scalar_type()));         \
+    }                                                        \
+    break;
+
+  switch (input.scalar_type()) {
+    ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
+    default:
+      ET_CHECK_MSG(
+          false,
+          "Unhandled input dtype %" PRId8,
+          static_cast<int8_t>(input.scalar_type()));
+  }
+#undef CALCULATE_FLOAT_TYPE
+#undef QUANTIZE_IMPL
+
+  return out;
+}
+
+Tensor& dequantize_per_channel_out(
+    RuntimeContext& context,
+    const Tensor& input,
+    const Tensor& scale,
+    const Tensor& zero_point,
+    int64_t axis,
+    int64_t quant_min,
+    int64_t quant_max,
+    ScalarType dtype,
+    Tensor& out) {
+  (void)context;
+  return dequantize_per_channel_out(
+      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
+}
+
 Tensor& dequantize_per_tensor_out(
     RuntimeContext& context,
     const Tensor& input,
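The macro-based dispatch above picks CTYPE_IN from the integral input dtype and CTYPE_OUT from the floating-point output dtype. As a reading aid, here is a minimal standalone sketch (our own plain C++, not ExecuTorch code) of what one DEQUANTIZE_IMPL expansion does once apply_over_dim_list is unrolled, assuming a contiguous {rows, cols} input, axis = 1, and uint8_t in / float out:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch only: per-channel dequantization of a contiguous 2-D tensor with
// axis = 1. The outer loop walks channels (as in the kernel); the inner loop
// stands in for apply_over_dim_list, enumerating the flat indices whose
// column equals channel_ix.
std::vector<float> dequantize_per_channel_2d(
    const std::vector<uint8_t>& input,
    const std::vector<double>& scales,       // one entry per channel
    const std::vector<int64_t>& zero_points, // one entry per channel
    size_t rows,
    size_t cols) {
  std::vector<float> out(rows * cols);
  for (size_t channel_ix = 0; channel_ix < cols; ++channel_ix) {
    const double scale = scales[channel_ix];
    const int64_t zero_point = zero_points[channel_ix];
    for (size_t row = 0; row < rows; ++row) {
      const size_t in_ix = row * cols + channel_ix; // flat index in this channel
      out[in_ix] = static_cast<float>((input[in_ix] - zero_point) * scale);
    }
  }
  return out;
}

int main() {
  // Mirrors the first case in op_dequantize_test.cpp below: a {3, 2} tensor
  // of 100s, scales {0.5, 1}, zero points {30, 60}, axis = 1.
  const auto out = dequantize_per_channel_2d(
      std::vector<uint8_t>(6, 100), {0.5, 1.0}, {30, 60}, 3, 2);
  for (const float v : out) {
    std::printf("%g ", v); // prints: 35 40 35 40 35 40
  }
  std::printf("\n");
  return 0;
}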

kernels/quantized/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,12 @@ _QUANT_OPS = (
     ),
     op_target(
         name = "op_dequantize",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:reduce_util",
+        ],
+        _aten_mode_deps = [
+            "//executorch/kernels/portable/cpu/util:reduce_util_aten",
+        ],
     ),
     op_target(
         name = "op_embedding",

kernels/quantized/quantized.yaml

Lines changed: 6 additions & 0 deletions
@@ -28,6 +28,12 @@
   - arg_meta: null
     kernel_name: torch::executor::quantize_per_channel_out
 
+- func: dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::dequantize_per_channel_out
+
 - func: quantized_decomposed::embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor weight_zero_points, int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 50 additions & 0 deletions
@@ -21,6 +21,7 @@ using exec_aten::ArrayRef;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
+using torch::executor::native::dequantize_per_channel_out;
 using torch::executor::native::dequantize_per_tensor_out;
 using torch::executor::native::dequantize_per_tensor_tensor_args_out;
 using torch::executor::testing::TensorFactory;
@@ -90,3 +91,52 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
 
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpDequantizeOutTest, DequantizePerChannel) {
+  TensorFactory<ScalarType::Byte> tf_byte;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  Tensor input = tf_byte.full({3, 2}, 100);
+  Tensor scale = tf_double.make({2}, {0.5, 1});
+  Tensor zero_point = tf_long.make({2}, {30, 60});
+  int64_t quant_min = 0;
+  int64_t quant_max = 255;
+
+  TensorFactory<ScalarType::Float> tfo;
+  Tensor out = tfo.zeros({3, 2});
+  // (100 - 30) * 0.5
+  // (100 - 60) * 1
+  Tensor expected = tfo.make({3, 2}, {35, 40, 35, 40, 35, 40});
+  dequantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      /*axis=*/1,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+
+  // Test with a different axis
+  out = tfo.zeros({3, 2});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
+  zero_point = tf_long.make({3}, {30, 50, 60});
+  // (100 - 30) * 0.5
+  // (100 - 50) * 0.75
+  // (100 - 60) * 1
+  expected = tfo.make({3, 2}, {35, 35, 37.5, 37.5, 40, 40});
+  dequantize_per_channel_out(
+      input,
+      scale,
+      zero_point,
+      /*axis=*/0,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
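In the second test case, axis = 0 makes the channel index the row rather than the column. The kernel never computes a channel from a flat index (it loops over channels and lets apply_over_dim_list enumerate the matching flat indices), but for intuition, here is our own small sketch (not part of the change) of how the channel of a flat index can be recovered in a contiguous tensor:

#include <cassert>
#include <cstddef>
#include <vector>

// Channel index of flat index `in_ix` along `axis` for a contiguous
// (row-major) tensor: divide by the stride of `axis`, then wrap by its size.
std::size_t channel_of(
    const std::vector<std::size_t>& sizes,
    std::size_t axis,
    std::size_t in_ix) {
  std::size_t stride = 1;
  for (std::size_t d = axis + 1; d < sizes.size(); ++d) {
    stride *= sizes[d]; // product of trailing dimension sizes
  }
  return (in_ix / stride) % sizes[axis];
}

int main() {
  // Shape {3, 2} with axis = 0, as in the second test case: flat indices
  // 0,1 map to channel 0; 2,3 to channel 1; 4,5 to channel 2.
  for (std::size_t in_ix = 0; in_ix < 6; ++in_ix) {
    assert(channel_of({3, 2}, 0, in_ix) == in_ix / 2);
  }
  return 0;
}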
