@@ -168,6 +168,19 @@ Tensor& dequantize_per_tensor_tensor_args_out(
168168 return out;
169169}
170170
171+ float get_scale (const Tensor& scale, size_t channel_ix) {
172+ ET_CHECK_MSG (
173+ (scale.scalar_type () == ScalarType::Double) ||
174+ (scale.scalar_type () == ScalarType::Float),
175+ " scale.scalar_type() %" PRId8 " is not double or float type" ,
176+ static_cast <int8_t >(scale.scalar_type ()));
177+ if (scale.scalar_type () == ScalarType::Double) {
178+ return static_cast <float >(scale.const_data_ptr <double >()[channel_ix]);
179+ } else {
180+ return scale.const_data_ptr <float >()[channel_ix];
181+ }
182+ }
183+
171184Tensor& dequantize_per_channel_out (
172185 const Tensor& input,
173186 const Tensor& scale,
@@ -189,11 +202,6 @@ Tensor& dequantize_per_channel_out(
189202 axis += nonzero_dim (input);
190203 }
191204
192- ET_CHECK_MSG (
193- scale.scalar_type () == ScalarType::Double,
194- " scale.scalar_type() %" PRId8 " is not double type" ,
195- static_cast <int8_t >(scale.scalar_type ()));
196-
197205 ET_CHECK_MSG (
198206 scale.numel () == input.size (axis),
199207 " scale.numel() %zd != input.size(axis) %zd" ,
@@ -226,7 +234,6 @@ Tensor& dequantize_per_channel_out(
226234 dims[i] = i + 1 ;
227235 }
228236 }
229- const double * scale_data = scale.const_data_ptr <double >();
230237 const int64_t * zero_point_data;
231238 if (opt_zero_points.has_value ()) {
232239 zero_point_data = opt_zero_points.value ().const_data_ptr <int64_t >();
@@ -254,11 +261,11 @@ Tensor& dequantize_per_channel_out(
254261 axis == 0 , " Axis must be 0 for a single dimensional tensors" ); \
255262 const optional<int64_t > dim; \
256263 apply_over_dim ( \
257- [input_data_ptr, out_data_ptr, scale_data, zero_point_data]( \
264+ [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \
258265 size_t numel, size_t stride, size_t base_ix) { \
259266 for (size_t i = 0 ; i < numel; i++) { \
260267 size_t current_ix = base_ix * stride + i; \
261- float _scale = static_cast < float >(scale_data[ current_ix]); \
268+ float _scale = get_scale (scale, current_ix); \
262269 int64_t zero_point = 0 ; \
263270 if (zero_point_data != nullptr ) { \
264271 zero_point = zero_point_data[current_ix]; \
@@ -274,7 +281,7 @@ Tensor& dequantize_per_channel_out(
274281 break ; \
275282 } \
276283 for (size_t channel_ix = 0 ; channel_ix < input.size (axis); ++channel_ix) { \
277- float _scale = static_cast < float >(scale_data[ channel_ix]); \
284+ float _scale = get_scale (scale, channel_ix); \
278285 int64_t _zero_point = 0 ; \
279286 if (zero_point_data != nullptr ) { \
280287 _zero_point = zero_point_data[channel_ix]; \
0 commit comments