@@ -196,7 +196,7 @@ Tensor& dequantize_per_channel_out(
196196 " Failed to resize out Tensor in dequantize_per_channel_out" );
197197
198198 ET_CHECK_MSG (
199- scale.scalar_type () == ScalarType::Float ,
199+ scale.scalar_type () == ScalarType::Double ,
200200 " scale.scalar_type() %" PRId8 " is not float type" ,
201201 static_cast <int8_t >(scale.scalar_type ()));
202202
@@ -232,7 +232,7 @@ Tensor& dequantize_per_channel_out(
232232 dims[i] = i + 1 ;
233233 }
234234 }
235- const float * scale_data = scale.const_data_ptr <float >();
235+ const double * scale_data = scale.const_data_ptr <double >();
236236 const int64_t * zero_point_data;
237237 if (opt_zero_points.has_value ()) {
238238 zero_point_data = opt_zero_points.value ().const_data_ptr <int64_t >();
@@ -264,7 +264,7 @@ Tensor& dequantize_per_channel_out(
264264 size_t numel, size_t stride, size_t base_ix) { \
265265 for (size_t i = 0 ; i < numel; i++) { \
266266 size_t current_ix = base_ix * stride + i; \
267- float _scale = scale_data[current_ix]; \
267+ float _scale = static_cast < float >( scale_data[current_ix]); \
268268 int64_t zero_point = 0 ; \
269269 if (zero_point_data != nullptr ) { \
270270 zero_point = zero_point_data[current_ix]; \
@@ -280,7 +280,7 @@ Tensor& dequantize_per_channel_out(
280280 break ; \
281281 } \
282282 for (size_t channel_ix = 0 ; channel_ix < input.size (axis); ++channel_ix) { \
283- float _scale = scale_data[channel_ix]; \
283+ float _scale = static_cast < float >( scale_data[channel_ix]); \
284284 int64_t _zero_point = 0 ; \
285285 if (zero_point_data != nullptr ) { \
286286 _zero_point = zero_point_data[channel_ix]; \
0 commit comments