Skip to content

Commit cf700d9

Browse files
committed
Fix dequantize per channel to handle double scale type
Pull Request resolved: pytorch/executorch#5524 ghstack-source-id: 244449200 @exported-using-ghexport Differential Revision: [D62301839](https://our.internmc.facebook.com/intern/diff/D62301839/)
1 parent 1c793ab commit cf700d9

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,8 @@ Tensor& dequantize_per_channel_out(
196196
"Failed to resize out Tensor in dequantize_per_channel_out");
197197

198198
ET_CHECK_MSG(
199-
scale.scalar_type() == ScalarType::Float,
200-
"scale.scalar_type() %" PRId8 " is not float type",
199+
scale.scalar_type() == ScalarType::Double,
200+
"scale.scalar_type() %" PRId8 " is not double type",
201201
static_cast<int8_t>(scale.scalar_type()));
202202

203203
ET_CHECK_MSG(
@@ -232,7 +232,7 @@ Tensor& dequantize_per_channel_out(
232232
dims[i] = i + 1;
233233
}
234234
}
235-
const float* scale_data = scale.const_data_ptr<float>();
235+
const double* scale_data = scale.const_data_ptr<double>();
236236
const int64_t* zero_point_data;
237237
if (opt_zero_points.has_value()) {
238238
zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
@@ -264,7 +264,7 @@ Tensor& dequantize_per_channel_out(
264264
size_t numel, size_t stride, size_t base_ix) { \
265265
for (size_t i = 0; i < numel; i++) { \
266266
size_t current_ix = base_ix * stride + i; \
267-
float _scale = scale_data[current_ix]; \
267+
float _scale = static_cast<float>(scale_data[current_ix]); \
268268
int64_t zero_point = 0; \
269269
if (zero_point_data != nullptr) { \
270270
zero_point = zero_point_data[current_ix]; \
@@ -280,7 +280,7 @@ Tensor& dequantize_per_channel_out(
280280
break; \
281281
} \
282282
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
283-
float _scale = scale_data[channel_ix]; \
283+
float _scale = static_cast<float>(scale_data[channel_ix]); \
284284
int64_t _zero_point = 0; \
285285
if (zero_point_data != nullptr) { \
286286
_zero_point = zero_point_data[channel_ix]; \

kernels/quantized/test/op_dequantize_test.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
1212
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
1313
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14+
#include <executorch/runtime/platform/runtime.h>
1415
#include <executorch/test/utils/DeathTest.h>
1516

1617
#include <gtest/gtest.h>
@@ -57,10 +58,12 @@ void test_dtype() {
5758
}
5859

5960
TEST(OpDequantizeOutTest, AllDtypesSupported) {
61+
et_pal_init();
6062
test_dtype<ScalarType::Byte>();
6163
}
6264

6365
TEST(OpDequantizeOutTest, NonWholeNumbers) {
66+
et_pal_init();
6467
TensorFactory<ScalarType::Byte> tf;
6568

6669
Tensor input = tf.full({3, 5}, 100);
@@ -87,6 +90,7 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
8790
}
8891

8992
TEST(OpDequantizeOutTest, TensorArgOverload) {
93+
et_pal_init();
9094
TensorFactory<ScalarType::Byte> tf_byte;
9195
TensorFactory<ScalarType::Double> tf_double;
9296
TensorFactory<ScalarType::Long> tf_long;
@@ -115,12 +119,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
115119
}
116120

117121
TEST(OpDequantizeOutTest, DequantizePerChannel) {
122+
et_pal_init();
118123
TensorFactory<ScalarType::Byte> tf_byte;
119-
TensorFactory<ScalarType::Float> tf_float;
124+
TensorFactory<ScalarType::Double> tf_double;
120125
TensorFactory<ScalarType::Long> tf_long;
121126

122127
Tensor input = tf_byte.full({3, 2}, 100);
123-
Tensor scale = tf_float.make({2}, {0.5, 1});
128+
Tensor scale = tf_double.make({2}, {0.5, 1});
124129
Tensor zero_point = tf_long.make({2}, {30, 60});
125130
int64_t quant_min = 0;
126131
int64_t quant_max = 255;
@@ -145,7 +150,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
145150

146151
// Test with a different axis
147152
out = tfo.zeros({3, 2});
148-
scale = tf_float.make({3}, {0.5, 0.75, 1});
153+
scale = tf_double.make({3}, {0.5, 0.75, 1});
149154
zero_point = tf_long.make({3}, {30, 50, 60});
150155
// (100 - 30) * 0.5
151156
// (100 - 50) * 0.75
@@ -167,7 +172,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
167172
// Test with a different axis
168173
out = tfo.zeros({3});
169174
input = tf_byte.make({3}, {100, 100, 100});
170-
scale = tf_float.make({3}, {0.5, 0.75, 1});
175+
scale = tf_double.make({3}, {0.5, 0.75, 1});
171176
zero_point = tf_long.make({3}, {30, 50, 60});
172177
// (100 - 30) * 0.5
173178
// (100 - 50) * 0.75

0 commit comments

Comments
 (0)