
Commit 9d224a5

kimishpatel authored and facebook-github-bot committed
Fix dequantize per channel to handle double scale type (#5524)
Summary:
Pull Request resolved: #5524
ghstack-source-id: 244685036
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D62301839

fbshipit-source-id: ac969b80fda97adacef0ad6afab3bc0cf34050b0
1 parent 985f92d commit 9d224a5
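
The previous implementation hard-required a Float scale tensor and read it through a raw const float* pointer; this commit replaces that with a get_scale() helper that accepts either Float or Double scales and converts each per-channel value to float at the point of use. Below is a minimal standalone sketch of that dispatch pattern; ScaleTensor and DType are illustrative stand-ins for ExecuTorch's Tensor and ScalarType, so only the branching logic mirrors the real helper in the diff further down.

    // Standalone sketch of the scale-dtype dispatch this commit introduces.
    // ScaleTensor and DType are hypothetical stand-ins for ExecuTorch's
    // Tensor and ScalarType; the branch mirrors the new get_scale() helper.
    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    enum class DType { Float, Double };

    struct ScaleTensor {
      DType dtype;
      std::vector<float> f32;   // used when dtype == DType::Float
      std::vector<double> f64;  // used when dtype == DType::Double
    };

    // Accept either dtype and always hand a float back to the kernel loop.
    float get_scale(const ScaleTensor& scale, size_t channel_ix) {
      assert(scale.dtype == DType::Float || scale.dtype == DType::Double);
      if (scale.dtype == DType::Double) {
        return static_cast<float>(scale.f64[channel_ix]);
      }
      return scale.f32[channel_ix];
    }

    int main() {
      ScaleTensor as_double{DType::Double, {}, {0.5, 1.0}};
      ScaleTensor as_float{DType::Float, {0.5f, 1.0f}, {}};
      std::printf("%g %g\n", get_scale(as_double, 0), get_scale(as_float, 1));  // 0.5 1
    }

In the actual kernel, the macro loops now capture the scale tensor by reference and call the helper per channel, instead of indexing a pre-fetched float pointer.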

File tree

kernels/quantized/cpu/op_dequantize.cpp
kernels/quantized/test/op_dequantize_test.cpp

2 files changed: +25 −13 lines

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 16 additions & 9 deletions
@@ -168,6 +168,19 @@ Tensor& dequantize_per_tensor_tensor_args_out(
   return out;
 }
 
+float get_scale(const Tensor& scale, size_t channel_ix) {
+  ET_CHECK_MSG(
+      (scale.scalar_type() == ScalarType::Double) ||
+          (scale.scalar_type() == ScalarType::Float),
+      "scale.scalar_type() %" PRId8 " is not double or float type",
+      static_cast<int8_t>(scale.scalar_type()));
+  if (scale.scalar_type() == ScalarType::Double) {
+    return static_cast<float>(scale.const_data_ptr<double>()[channel_ix]);
+  } else {
+    return scale.const_data_ptr<float>()[channel_ix];
+  }
+}
+
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
@@ -195,11 +208,6 @@ Tensor& dequantize_per_channel_out(
       err == torch::executor::Error::Ok,
       "Failed to resize out Tensor in dequantize_per_channel_out");
 
-  ET_CHECK_MSG(
-      scale.scalar_type() == ScalarType::Float,
-      "scale.scalar_type() %" PRId8 " is not float type",
-      static_cast<int8_t>(scale.scalar_type()));
-
   ET_CHECK_MSG(
       scale.numel() == input.size(axis),
       "scale.numel() %zd != input.size(axis) %zd",
@@ -232,7 +240,6 @@ Tensor& dequantize_per_channel_out(
       dims[i] = i + 1;
     }
   }
-  const float* scale_data = scale.const_data_ptr<float>();
   const int64_t* zero_point_data;
   if (opt_zero_points.has_value()) {
     zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
@@ -260,11 +267,11 @@ Tensor& dequantize_per_channel_out(
         axis == 0, "Axis must be 0 for a single dimensional tensors"); \
     const optional<int64_t> dim; \
     apply_over_dim( \
-        [input_data_ptr, out_data_ptr, scale_data, zero_point_data]( \
+        [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \
            size_t numel, size_t stride, size_t base_ix) { \
          for (size_t i = 0; i < numel; i++) { \
            size_t current_ix = base_ix * stride + i; \
-            float _scale = scale_data[current_ix]; \
+            float _scale = get_scale(scale, current_ix); \
            int64_t zero_point = 0; \
            if (zero_point_data != nullptr) { \
              zero_point = zero_point_data[current_ix]; \
@@ -280,7 +287,7 @@ Tensor& dequantize_per_channel_out(
       break; \
     } \
     for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
-      float _scale = scale_data[channel_ix]; \
+      float _scale = get_scale(scale, channel_ix); \
       int64_t _zero_point = 0; \
       if (zero_point_data != nullptr) { \
        _zero_point = zero_point_data[channel_ix]; \

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 9 additions & 4 deletions
@@ -11,6 +11,7 @@
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
+#include <executorch/runtime/platform/runtime.h>
 #include <executorch/test/utils/DeathTest.h>
 
 #include <gtest/gtest.h>
@@ -57,10 +58,12 @@ void test_dtype() {
 }
 
 TEST(OpDequantizeOutTest, AllDtypesSupported) {
+  et_pal_init();
   test_dtype<ScalarType::Byte>();
 }
 
 TEST(OpDequantizeOutTest, NonWholeNumbers) {
+  et_pal_init();
   TensorFactory<ScalarType::Byte> tf;
 
   Tensor input = tf.full({3, 5}, 100);
@@ -87,6 +90,7 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
 }
 
 TEST(OpDequantizeOutTest, TensorArgOverload) {
+  et_pal_init();
   TensorFactory<ScalarType::Byte> tf_byte;
   TensorFactory<ScalarType::Double> tf_double;
   TensorFactory<ScalarType::Long> tf_long;
@@ -115,12 +119,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
 }
 
 TEST(OpDequantizeOutTest, DequantizePerChannel) {
+  et_pal_init();
   TensorFactory<ScalarType::Byte> tf_byte;
-  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
   TensorFactory<ScalarType::Long> tf_long;
 
   Tensor input = tf_byte.full({3, 2}, 100);
-  Tensor scale = tf_float.make({2}, {0.5, 1});
+  Tensor scale = tf_double.make({2}, {0.5, 1});
   Tensor zero_point = tf_long.make({2}, {30, 60});
   int64_t quant_min = 0;
   int64_t quant_max = 255;
@@ -145,7 +150,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
 
   // Test with a different axis
   out = tfo.zeros({3, 2});
-  scale = tf_float.make({3}, {0.5, 0.75, 1});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
   zero_point = tf_long.make({3}, {30, 50, 60});
   // (100 - 30) * 0.5
   // (100 - 50) * 0.75
@@ -167,7 +172,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
   // Test with a different axis
   out = tfo.zeros({3});
   input = tf_byte.make({3}, {100, 100, 100});
-  scale = tf_float.make({3}, {0.5, 0.75, 1});
+  scale = tf_double.make({3}, {0.5, 0.75, 1});
   zero_point = tf_long.make({3}, {30, 50, 60});
   // (100 - 30) * 0.5
   // (100 - 50) * 0.75
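
For reference, the expected values in the comments above follow the usual per-channel dequantization formula out = (q - zero_point[c]) * scale[c], where c indexes the element along the chosen axis. A quick standalone check of that arithmetic (plain C++, not part of the commit):

    #include <cstdio>

    // Element-wise per-channel dequantization: (q - zp) * s.
    static double dequant(int q, int zp, double s) {
      return (q - zp) * s;
    }

    int main() {
      // Channel params used by the test: scales {0.5, 0.75, 1}, zero points {30, 50, 60}.
      std::printf("%g %g %g\n",
                  dequant(100, 30, 0.5),    // 35
                  dequant(100, 50, 0.75),   // 37.5
                  dequant(100, 60, 1.0));   // 40
    }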
