Skip to content

Commit 3fbc178

Browse files
committed
Fix dequantize per channel to handle double scale type
Differential Revision: [D62301839](https://our.internmc.facebook.com/intern/diff/D62301839/) ghstack-source-id: 243859225 Pull Request resolved: #5524
1 parent 10a1d5f commit 3fbc178

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

kernels/quantized/cpu/op_dequantize.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ Tensor& dequantize_per_channel_out(
196196
"Failed to resize out Tensor in dequantize_per_channel_out");
197197

198198
ET_CHECK_MSG(
199-
scale.scalar_type() == ScalarType::Float,
199+
scale.scalar_type() == ScalarType::Double,
200200
"scale.scalar_type() %" PRId8 " is not float type",
201201
static_cast<int8_t>(scale.scalar_type()));
202202

@@ -232,7 +232,7 @@ Tensor& dequantize_per_channel_out(
232232
dims[i] = i + 1;
233233
}
234234
}
235-
const float* scale_data = scale.const_data_ptr<float>();
235+
const double* scale_data = scale.const_data_ptr<double>();
236236
const int64_t* zero_point_data;
237237
if (opt_zero_points.has_value()) {
238238
zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
@@ -264,7 +264,7 @@ Tensor& dequantize_per_channel_out(
264264
size_t numel, size_t stride, size_t base_ix) { \
265265
for (size_t i = 0; i < numel; i++) { \
266266
size_t current_ix = base_ix * stride + i; \
267-
float _scale = scale_data[current_ix]; \
267+
float _scale = static_cast<float>(scale_data[current_ix]); \
268268
int64_t zero_point = 0; \
269269
if (zero_point_data != nullptr) { \
270270
zero_point = zero_point_data[current_ix]; \
@@ -280,7 +280,7 @@ Tensor& dequantize_per_channel_out(
280280
break; \
281281
} \
282282
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
283-
float _scale = scale_data[channel_ix]; \
283+
float _scale = static_cast<float>(scale_data[channel_ix]); \
284284
int64_t _zero_point = 0; \
285285
if (zero_point_data != nullptr) { \
286286
_zero_point = zero_point_data[channel_ix]; \

0 commit comments

Comments
 (0)