@@ -92,7 +92,8 @@ void check_dequantize_per_tensor_args(
 } // namespace
 
 /* Local function which calls the kernels based on the input datatype */
-Tensor& dequantize_impl(KernelRuntimeContext& ctx,
+Tensor& dequantize_impl(
+    KernelRuntimeContext& ctx,
     Tensor& out,
     const Tensor& input,
     float* scale_data,
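Every dtype branch inside dequantize_impl evaluates the same affine formula, out[i] = (in[i] - zero_point) * scale; only the NNLib kernel chosen varies with the input type. A minimal scalar sketch of the per-tensor case (a hypothetical reference helper, not code from this file):

#include <cstddef>
#include <cstdint>

// Reference sketch of affine dequantization for a uint8 input with a
// single (scale, zero_point) pair: out[i] = (in[i] - zero_point) * scale.
void dequantize_per_tensor_ref(
    const uint8_t* in, size_t n, float scale, int zero_point, float* out) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = (static_cast<int>(in[i]) - zero_point) * scale;
  }
}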
@@ -132,82 +133,82 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
   if (is_asym_dequant) {
     if (input.scalar_type() == ScalarType::Byte) {
       const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_asym8u_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        zero_point_data,
-        scale_data);
+          xa_nn_elm_dequantize_asym8u_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
     } else if (input.scalar_type() == ScalarType::Char) {
       const int8_t* input_data = input.const_data_ptr<int8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_asym8_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        zero_point_data,
-        scale_data);
+          xa_nn_elm_dequantize_asym8_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
     } else if (input.scalar_type() == (ScalarType)Ushort) {
       const uint16_t* input_data = input.const_data_ptr<uint16_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_asym16u_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        zero_point_data,
-        scale_data);
+          xa_nn_elm_dequantize_asym16u_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
     } else if (input.scalar_type() == ScalarType::Short) {
       const int16_t* input_data = input.const_data_ptr<int16_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
          ctx,
           out,
-        xa_nn_elm_dequantize_asym16_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        zero_point_data,
-        scale_data);
+          xa_nn_elm_dequantize_asym16_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
     } else if (input.scalar_type() == (ScalarType)Bits4u) {
       const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_asym4u_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        zero_point_data,
-        scale_data);
+          xa_nn_elm_dequantize_asym4u_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
     } else if (input.scalar_type() == (ScalarType)Bits4) {
       const int8_t* input_data = input.const_data_ptr<int8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_asym4_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        zero_point_data,
-        scale_data);
+          xa_nn_elm_dequantize_asym4_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          zero_point_data,
+          scale_data);
     } else {
       if (axis == NULL) {
 // calculate the dequantized output, cast scale to float to match fbgemm
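The fallback branch entered here handles dtypes without a dedicated NNLib kernel, and the comment pins down one subtlety: scale is cast from double to float before the multiply so results bit-match fbgemm's reference path. A sketch of that order of operations (assumed from the comment; the helper name is illustrative):

#include <cstdint>

// Narrow scale to float first, then multiply in float; multiplying in
// double and narrowing the product can differ in the last bit.
inline float dequant_one_fbgemm_style(int8_t q, int zero_point, double scale) {
  const float scale_f = static_cast<float>(scale);
  return static_cast<float>(q - zero_point) * scale_f;
}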
@@ -343,10 +344,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
   } else {
     if (input.scalar_type() == ScalarType::Byte) {
       const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_sym8u_f32,
+          xa_nn_elm_dequantize_sym8u_f32,
           out_data,
           input_data,
           inp_shape,
@@ -358,19 +359,19 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
       XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_sym8_f32,
-        out_data,
-        input_data,
-        inp_shape,
-        input.dim(),
-        axis,
-        scale_data);
+          xa_nn_elm_dequantize_sym8_f32,
+          out_data,
+          input_data,
+          inp_shape,
+          input.dim(),
+          axis,
+          scale_data);
     } else if (input.scalar_type() == (ScalarType)Ushort) {
       const uint16_t* input_data = input.const_data_ptr<uint16_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_sym16u_f32,
+          xa_nn_elm_dequantize_sym16u_f32,
           out_data,
           input_data,
           inp_shape,
@@ -379,10 +380,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
           scale_data);
     } else if (input.scalar_type() == ScalarType::Short) {
       const int16_t* input_data = input.const_data_ptr<int16_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_sym16_f32,
+          xa_nn_elm_dequantize_sym16_f32,
           out_data,
           input_data,
           inp_shape,
@@ -391,10 +392,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
           scale_data);
     } else if (input.scalar_type() == (ScalarType)Bits4u) {
       const uint8_t* input_data = input.const_data_ptr<uint8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_sym4u_f32,
+          xa_nn_elm_dequantize_sym4u_f32,
           out_data,
           input_data,
           inp_shape,
@@ -403,10 +404,10 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
           scale_data);
     } else if (input.scalar_type() == (ScalarType)Bits4) {
       const int8_t* input_data = input.const_data_ptr<int8_t>();
-        XT_KERNEL_CHECK(
+      XT_KERNEL_CHECK(
           ctx,
           out,
-        xa_nn_elm_dequantize_sym4_f32,
+          xa_nn_elm_dequantize_sym4_f32,
           out_data,
           input_data,
           inp_shape,
@@ -558,7 +559,8 @@ Tensor& dequantize_impl(KernelRuntimeContext& ctx,
 * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more
 * info.
 */
-Tensor& dequantize_per_tensor_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_tensor_out(
+    KernelRuntimeContext& context,
     const Tensor& input,
     double scale,
     int64_t zero_point,
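As a worked example of this operator's math: with scale = 0.05 and zero_point = 128, a uint8 input value of 200 dequantizes to (200 - 128) * 0.05 = 3.6. Note that scale arrives as a double at the operator boundary (see the linked PyTorch discussion above) and is narrowed to float before the kernel call in the body below.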
@@ -572,20 +574,22 @@ Tensor& dequantize_per_tensor_out(KernelRuntimeContext& context,
   ET_CHECK_MSG(
       err == torch::executor::Error::Ok,
       "Failed to resize out Tensor in dequantize_per_tensor_out");
-
+
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
 #endif
 
   float scale_data = (float)scale;
   int zero_point_data = (int)zero_point;
 
-  dequantize_impl(context, out, input, &scale_data, &zero_point_data, NULL, out_dtype);
+  dequantize_impl(
+      context, out, input, &scale_data, &zero_point_data, NULL, out_dtype);
 
   return out;
 }
 
-Tensor& dequantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_tensor_tensor_args_out(
+    KernelRuntimeContext& context,
     const Tensor& input,
     const Tensor& scale,
     const Tensor& zero_point,
@@ -613,7 +617,8 @@ Tensor& dequantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
       ssize_t(zero_point.numel()));
 #endif
 
-  dequantize_per_tensor_out(context,
+  dequantize_per_tensor_out(
+      context,
       input,
       scale.const_data_ptr<double>()[0],
       zero_point.const_data_ptr<int64_t>()[0],
@@ -626,7 +631,8 @@ Tensor& dequantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
   return out;
 }
 
-Tensor& dequantize_per_channel_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_channel_out(
+    KernelRuntimeContext& context,
     const Tensor& input,
     const Tensor& scale,
     const exec_aten::optional<Tensor>& opt_zero_points,
@@ -636,14 +642,13 @@ Tensor& dequantize_per_channel_out(KernelRuntimeContext& context,
     ScalarType dtype,
     exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
-
   if (axis < 0) {
     axis += executorch::runtime::nonzero_dim(input);
   }
-  /* if the arguments are passed properly to the operator disable the Macro - "OP_ARG_CHECK"
-   * if not the case, enable the Macro - "OP_ARG_CHECK", to have the checks only in
-   * operator level(As there are no checks in kernel).
-   */
+  /* If the arguments are passed to the operator properly, disable the macro
+   * "OP_ARG_CHECK"; if not, enable it so the checks run at the operator level
+   * (the kernels themselves perform no checks).
+   */
 #ifdef OP_ARG_CHECK
   torch::executor::Error err = resize_tensor(out, input.sizes());
 
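The convention described in the comment above is a plain compile-time toggle: argument validation is compiled in only when OP_ARG_CHECK is defined, because the NNLib kernels perform no checking of their own. Condensed from this file's own code, the pattern is:

#ifdef OP_ARG_CHECK
  // Only compiled in when the build defines OP_ARG_CHECK: resize the
  // output and validate arguments at the operator level.
  torch::executor::Error err = resize_tensor(out, input.sizes());
  ET_CHECK_MSG(
      err == torch::executor::Error::Ok,
      "Failed to resize out Tensor in dequantize_per_channel_out");
#endif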
@@ -705,12 +710,14 @@ Tensor& dequantize_per_channel_out(KernelRuntimeContext& context,
   for (int i = 0; i < scale.numel(); i++) {
     scale_data[i] = (float)scale_dt[i];
   }
-  dequantize_impl(context, out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
+  dequantize_impl(
+      context, out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype);
 
   return out;
 }
 
-Tensor& dequantize_per_token_out(KernelRuntimeContext& context,
+Tensor& dequantize_per_token_out(
+    KernelRuntimeContext& context,
     const Tensor& input,
     const Tensor& scale,
     const Tensor& zero_points,
@@ -757,7 +764,8 @@ Tensor& dequantize_per_token_out(KernelRuntimeContext& context,
       "Failed to resize out Tensor in dequantize_per_channel_out");
 #endif
 
-  return dequantize_per_channel_out(context,
+  return dequantize_per_channel_out(
+      context,
       reshaped_input,
       scale,
       zero_points,
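Worth noting for reviewers: dequantize_per_token_out delegates to dequantize_per_channel_out on reshaped_input, which amounts to treating each token as a channel. Assuming an input of shape [B, T, D] (hypothetical sizes for illustration), the reshape views it as [B * T, D], so that axis 0 indexes tokens and one (scale, zero_point) pair applies per token.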