66 * LICENSE file in the root directory of this source tree.
77 */
88
9- #include < executorch/backends/cadence/fusion_g3/operators/tensor_util.h>
9+ #include < executorch/backends/cadence/fusion_g3/operators/operators.h>
10+
11+ #include < xa_nnlib_kernels_api.h>
12+
13+ #include < executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
1014#include < executorch/kernels/portable/cpu/util/reduce_util.h>
1115#include < executorch/runtime/kernel/kernel_includes.h>
12- #include < xa_nnlib_kernels_api.h>
1316#include < algorithm>
1417#include < cinttypes>
1518#include < cmath>
1619
17- using exec_aten ::Scalar;
18- using exec_aten ::ScalarType;
19- using exec_aten ::Tensor;
20- using torch::executor ::Error;
21- using torch::executor ::KernelRuntimeContext;
20+ using ::executorch::aten ::Scalar;
21+ using ::executorch::aten ::ScalarType;
22+ using ::executorch::aten ::Tensor;
23+ using ::executorch::runtime ::Error;
24+ using ::executorch::runtime ::KernelRuntimeContext;
2225
2326template <typename T>
24- using optional = exec_aten ::optional<T>;
27+ using optional = ::executorch::aten ::optional<T>;
2528/* ScalarType in Executorch do not have support for below data types.
2629 * So, creating a placeholder for these data types. Once, ScalarTypes is
2730 * updated to have support for below data types, these can be removed and
@@ -48,7 +51,7 @@ void check_dequantize_per_tensor_args(
4851 int64_t quant_min,
4952 int64_t quant_max,
5053 ScalarType dtype,
51- exec_aten ::optional<ScalarType>& out_dtype,
54+ ::executorch::aten ::optional<ScalarType>& out_dtype,
5255 Tensor& out) {
5356 ET_CHECK_MSG (
5457 input.scalar_type () == ScalarType::Byte ||
@@ -91,8 +94,9 @@ Tensor& dequantize_impl(
9194 float * scale_data,
9295 int * zero_point_data,
9396 int * axis,
94- exec_aten::optional<ScalarType> out_dtype) {
95- const exec_aten::ArrayRef<Tensor::SizesType> input_size = input.sizes ();
97+ ::executorch::aten::optional<ScalarType> out_dtype) {
98+ const ::executorch::aten::ArrayRef<Tensor::SizesType> input_size =
99+ input.sizes ();
96100
97101 int kTensorDimensionLimit = 5 ;
98102
@@ -251,8 +255,9 @@ Tensor& dequantize_impl(
251255 }
252256 }
253257
254- exec_aten::optional<exec_aten::ArrayRef<int64_t >> optional_dim_list{
255- exec_aten::ArrayRef<int64_t >{dims, size_t (input.dim () - 1 )}};
258+ ::executorch::aten::optional<::executorch::aten::ArrayRef<int64_t >>
259+ optional_dim_list{::executorch::aten::ArrayRef<int64_t >{
260+ dims, size_t (input.dim () - 1 )}};
256261
257262// Actual dequantization logic
258263// input, out are the input and output tensors
@@ -456,8 +461,9 @@ Tensor& dequantize_impl(
456461 }
457462 }
458463
459- exec_aten::optional<exec_aten::ArrayRef<int64_t >> optional_dim_list{
460- exec_aten::ArrayRef<int64_t >{dims, size_t (input.dim () - 1 )}};
464+ ::executorch::aten::optional<::executorch::aten::ArrayRef<int64_t >>
465+ optional_dim_list{::executorch::aten::ArrayRef<int64_t >{
466+ dims, size_t (input.dim () - 1 )}};
461467
462468// Actual dequantization logic
463469// input, out are the input and output tensors
@@ -559,7 +565,7 @@ Tensor& dequantize_per_tensor_out(
559565 int64_t quant_min,
560566 int64_t quant_max,
561567 ScalarType dtype,
562- exec_aten ::optional<ScalarType> out_dtype,
568+ ::executorch::aten ::optional<ScalarType> out_dtype,
563569 Tensor& out) {
564570#ifdef OP_ARG_CHECK
565571 torch::executor::Error err = resize_tensor (out, input.sizes ());
@@ -588,7 +594,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
588594 int64_t quant_min,
589595 int64_t quant_max,
590596 ScalarType dtype,
591- exec_aten ::optional<ScalarType> out_dtype,
597+ ::executorch::aten ::optional<ScalarType> out_dtype,
592598 Tensor& out) {
593599#ifdef OP_ARG_CHECK
594600 ET_CHECK_MSG (
@@ -627,12 +633,12 @@ Tensor& dequantize_per_channel_out(
627633 KernelRuntimeContext& context,
628634 const Tensor& input,
629635 const Tensor& scale,
630- const exec_aten ::optional<Tensor>& opt_zero_points,
636+ const ::executorch::aten ::optional<Tensor>& opt_zero_points,
631637 int64_t axis,
632638 int64_t quant_min,
633639 int64_t quant_max,
634640 ScalarType dtype,
635- exec_aten ::optional<ScalarType> out_dtype,
641+ ::executorch::aten ::optional<ScalarType> out_dtype,
636642 Tensor& out) {
637643 if (axis < 0 ) {
638644 axis += executorch::runtime::nonzero_dim (input);
@@ -725,18 +731,18 @@ Tensor& dequantize_per_token_out(
725731 }
726732 // This unfortunate change is needed because we compile op_quantize for aten
727733 // mode as well
728- std::array<exec_aten ::SizesType, 2 > input_sizes;
729- input_sizes[0 ] = static_cast <exec_aten ::SizesType>(num_channels);
734+ std::array<::executorch::aten ::SizesType, 2 > input_sizes;
735+ input_sizes[0 ] = static_cast <::executorch::aten ::SizesType>(num_channels);
730736 input_sizes[1 ] =
731- static_cast <exec_aten ::SizesType>(input.size (input.dim () - 1 ));
737+ static_cast <::executorch::aten ::SizesType>(input.size (input.dim () - 1 ));
732738#ifdef USE_ATEN_LIB
733739 Tensor reshaped_input = at::from_blob (
734740 input.mutable_data_ptr (),
735741 input_sizes,
736742 at::TensorOptions (input.scalar_type ()));
737743#else
738- std::array<exec_aten ::DimOrderType, 2 > input_dim_order{0 , 1 };
739- std::array<exec_aten ::StridesType, 2 > input_strides;
744+ std::array<::executorch::aten ::DimOrderType, 2 > input_dim_order{0 , 1 };
745+ std::array<::executorch::aten ::StridesType, 2 > input_strides;
740746 executorch::runtime::dim_order_to_stride_nocheck (
741747 input_sizes.data (), input_dim_order.data (), 2 , input_strides.data ());
742748 void * input_data = input.mutable_data_ptr ();
@@ -769,22 +775,6 @@ Tensor& dequantize_per_token_out(
769775 out);
770776}
771777
772- Tensor& dequantize_per_token_out (
773- KernelRuntimeContext& context,
774- const Tensor& input,
775- const Tensor& scale,
776- const Tensor& zero_points,
777- int64_t quant_min,
778- int64_t quant_max,
779- ScalarType dtype,
780- ScalarType out_dtype,
781- Tensor& out)
782- {
783- (void )context;
784- return dequantize_per_token_out (
785- input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out);
786- }
787-
788778} // namespace native
789779} // namespace G3
790780} // namespace impl
0 commit comments