66 * LICENSE file in the root directory of this source tree.
77 */
88
9+ #include < executorch/kernels/portable/cpu/util/reduce_util.h>
910#include < executorch/runtime/kernel/kernel_includes.h>
1011#include < algorithm>
1112#include < cinttypes>
@@ -29,8 +30,6 @@ namespace {
2930 */
3031void check_dequantize_per_tensor_args (
3132 const Tensor& input,
32- double scale,
33- int64_t zero_point,
3433 int64_t quant_min,
3534 int64_t quant_max,
3635 ScalarType dtype,
@@ -58,9 +57,6 @@ void check_dequantize_per_tensor_args(
5857 " quant min: %" PRId64 " is greater than quant max: %" PRId64,
5958 quant_min,
6059 quant_max);
61-
62- (void )scale;
63- (void )zero_point;
6460}
6561
6662} // namespace
@@ -87,8 +83,7 @@ Tensor& dequantize_per_tensor_out(
8783 err == torch::executor::Error::Ok,
8884 " Failed to resize out Tensor in dequantize_per_tensor_out" );
8985
90- check_dequantize_per_tensor_args (
91- input, scale, zero_point, quant_min, quant_max, dtype, out);
86+ check_dequantize_per_tensor_args (input, quant_min, quant_max, dtype, out);
9287
9388 // calculate the dequantized output, cast scale to float to match fbgemm
9489 // behavior
@@ -162,6 +157,136 @@ Tensor& dequantize_per_tensor_tensor_args_out(
162157 return out;
163158}
164159
160+ Tensor& dequantize_per_channel_out (
161+ const Tensor& input,
162+ const Tensor& scale,
163+ const Tensor& zero_point,
164+ int64_t axis,
165+ int64_t quant_min,
166+ int64_t quant_max,
167+ ScalarType dtype,
168+ Tensor& out) {
169+ torch::executor::Error err = resize_tensor (out, input.sizes ());
170+
171+ // normalize axis
172+ ET_CHECK_MSG (
173+ tensor_has_dim (input, axis),
174+ " axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd" ,
175+ ssize_t (axis),
176+ ssize_t (input.dim ()));
177+
178+ if (axis < 0 ) {
179+ axis += nonzero_dim (input);
180+ }
181+
182+ ET_CHECK_MSG (
183+ err == torch::executor::Error::Ok,
184+ " Failed to resize out Tensor in dequantize_per_channel_out" );
185+
186+ ET_CHECK_MSG (
187+ scale.scalar_type () == ScalarType::Double,
188+ " scale.scalar_type() %" PRId8 " is not double type" ,
189+ static_cast <int8_t >(scale.scalar_type ()));
190+
191+ ET_CHECK_MSG (
192+ scale.numel () == input.size (axis),
193+ " scale.numel() %zd != input.size(axis) %zd" ,
194+ ssize_t (scale.numel ()),
195+ ssize_t (input.size (axis)));
196+
197+ ET_CHECK_MSG (
198+ zero_point.scalar_type () == ScalarType::Long,
199+ " zero_point.scalar_type() %" PRId8 " is not integer type" ,
200+ static_cast <int8_t >(zero_point.scalar_type ()));
201+
202+ ET_CHECK_MSG (
203+ zero_point.numel () == input.size (axis),
204+ " zero_point.numel() %zd != input.size(axis) %zd" ,
205+ ssize_t (zero_point.numel ()),
206+ ssize_t (input.size (axis)));
207+
208+ check_dequantize_per_tensor_args (input, quant_min, quant_max, dtype, out);
209+
210+ // a list contains all dimensions except axis
211+ int64_t dims[input.dim () - 1 ];
212+ for (int64_t i = 0 ; i < input.dim () - 1 ; i++) {
213+ if (i < axis) {
214+ dims[i] = i;
215+ } else {
216+ dims[i] = i - 1 ;
217+ }
218+ }
219+ const double * scale_data = scale.const_data_ptr <double >();
220+ const int64_t * zero_point_data = zero_point.const_data_ptr <int64_t >();
221+
222+ exec_aten::optional<exec_aten::ArrayRef<int64_t >> optional_dim_list{
223+ exec_aten::ArrayRef<int64_t >{dims, size_t (input.dim () - 1 )}};
224+
225+ // Actual dequantization logic
226+ // input, out are the input and output tensors
227+ // channel_ix is the index along the axis dimension. 0 <= channel_ix <
228+ // input.size(axis).
229+ // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
230+ // will be 0, 1, 2, ... C-1
231+ // in_ix is the flat index of the element you are dequantizing.
232+ // in other words you are dequantizing in_data[in_ix]
233+ #define DEQUANTIZE_IMPL (CTYPE_IN, CTYPE_OUT, out_dtype ) \
234+ case ScalarType::out_dtype: \
235+ for (size_t channel_ix = 0 ; channel_ix < input.size (axis); ++channel_ix) { \
236+ double _scale = scale_data[channel_ix]; \
237+ int64_t _zero_point = zero_point_data[channel_ix]; \
238+ apply_over_dim_list ( \
239+ [input, out, _scale, _zero_point](size_t in_ix) { \
240+ out.mutable_data_ptr <CTYPE_OUT>()[in_ix] = static_cast <CTYPE_OUT>( \
241+ (input.const_data_ptr <CTYPE_IN>()[in_ix] - _zero_point) * \
242+ _scale); \
243+ }, \
244+ input, \
245+ optional_dim_list, \
246+ channel_ix); \
247+ } \
248+ break ;
249+ #define CALCULATE_FLOAT_TYPE (CTYPE_IN, in_dtype ) \
250+ case ScalarType::in_dtype: \
251+ switch (out.scalar_type ()) { \
252+ ET_FORALL_FLOAT_TYPES_WITH (CTYPE_IN, DEQUANTIZE_IMPL); \
253+ default : \
254+ ET_CHECK_MSG ( \
255+ false , \
256+ " Unhandled output dtype %" PRId8, \
257+ static_cast <int8_t >(out.scalar_type ())); \
258+ } \
259+ break ;
260+
261+ switch (input.scalar_type ()) {
262+ ET_FORALL_INT_TYPES (CALCULATE_FLOAT_TYPE);
263+ default :
264+ ET_CHECK_MSG (
265+ false ,
266+ " Unhandled input dtype %" PRId8,
267+ static_cast <int8_t >(input.scalar_type ()));
268+ }
269+ #undef CALCULATE_FLOAT_TYPE
270+ #undef QUANTIZE_IMPL
271+
272+ return out;
273+ }
274+
275+ Tensor& dequantize_per_channel_out (
276+ RuntimeContext& context,
277+ const Tensor& input,
278+ const Tensor& scale,
279+ const Tensor& zero_point,
280+ int64_t axis,
281+ int64_t quant_min,
282+ int64_t quant_max,
283+ ScalarType dtype,
284+ Tensor& out) {
285+ (void )context;
286+ return dequantize_per_channel_out (
287+ input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
288+ }
289+
165290Tensor& dequantize_per_tensor_out (
166291 RuntimeContext& context,
167292 const Tensor& input,
0 commit comments