  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/backends/cadence/fusion_g3/operators/operators.h>
+#include <executorch/backends/cadence/fusion_g3/operators/xt_utils.h>
+
 #include <cstring>
 
 #include <xa_nnlib_kernels_api.h>
 
+#include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
+using ::executorch::aten::ArrayRef;
 using ::executorch::aten::ScalarType;
 using ::executorch::aten::Tensor;
 using ::executorch::runtime::Error;
@@ -23,7 +28,6 @@ using ::executorch::runtime::KernelRuntimeContext;
  * updated to have support for below data types, these can be removed and
  * operator need to be updated accordingly
  */
-enum datatype { Ushort = 20, Uint = 23 };
 
 namespace cadence {
 namespace impl {
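Context for the removal above: the local datatype enum only existed to smuggle raw tag values (20, 23) past a ScalarType that had no unsigned 16/32-bit entries. The updated dispatch later in this change compares against named entries instead, roughly along these lines (a minimal sketch assuming the ScalarType enum in use defines UInt16 and UInt32, which the added branch below relies on):

    // Old workaround: cast a hard-coded tag value to ScalarType.
    //   enum datatype { Ushort = 20, Uint = 23 };
    //   if (out.scalar_type() == (ScalarType)Uint) { ... }
    // New check: compare against the named enum entries directly.
    if (out.scalar_type() == ScalarType::UInt32 ||
        out.scalar_type() == ScalarType::UInt16) {
      // safe to hand the tensor to the NNLib kernel with the right element size
    }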
@@ -32,20 +36,22 @@ namespace native {
 
 Tensor& cat_out(
     KernelRuntimeContext& ctx,
-    exec_aten::ArrayRef<Tensor> tensors,
+    ArrayRef<Tensor> tensors,
     int64_t dim,
     Tensor& out) {
   if (dim < 0) {
     dim += out.dim();
   }
 
+  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
+
+#ifdef OP_ARG_CHECK
   ET_KERNEL_CHECK(
       ctx,
       torch::executor::check_cat_args(tensors, dim, out),
       InvalidArgument,
       out);
 
-  int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;
   Tensor::SizesType expected_out_size[kTensorDimensionLimit];
   size_t expected_out_dim = 0;
   torch::executor::get_cat_out_target_size(
@@ -57,14 +63,28 @@ Tensor& cat_out(
           out, {expected_out_size, expected_out_dim}) == Error::Ok,
       InvalidArgument,
       out);
+#endif
+  // Special handling when all inputs are 1D-empty tensors for aten
+  // consistency In that case, just return an 1D-empty tensor without checking
+  // dim
+  bool all_1d_empty = true;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
+      all_1d_empty = false;
+      break;
+    }
+  }
+  if (all_1d_empty) {
+    return out;
+  }
 
   const signed char* inp_tensors[tensors.size()];
   const int* inp_tensors_shapes[tensors.size()];
 
   int inp_shapes_size[tensors.size()];
 
   int temp_sizes[tensors.size()][kTensorDimensionLimit];
-  exec_aten::ArrayRef<Tensor::SizesType> temp_size;
+  ArrayRef<Tensor::SizesType> temp_size;
 
   for (int i = 0; i < tensors.size(); i++) {
     inp_tensors[i] = tensors[i].const_data_ptr<signed char>();
@@ -79,88 +99,32 @@ Tensor& cat_out(
 
   signed char* out_data = out.mutable_data_ptr<signed char>();
 
-  const exec_aten::ArrayRef<Tensor::SizesType> out_size = out.sizes();
+  const ArrayRef<Tensor::SizesType> out_size = out.sizes();
   int out_shapes[kTensorDimensionLimit];
   for (int i = 0; i < out_size.size(); i++) // output shapes
   {
     out_shapes[i] = out_size[i];
   }
 
-  if (out.scalar_type() == ScalarType::Int) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(int));
-  } else if (out.scalar_type() == ScalarType::Short) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(short));
-  } else if (out.scalar_type() == ScalarType::Char) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(char));
-  } else if (out.scalar_type() == (ScalarType)Uint) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(int));
-  } else if (out.scalar_type() == (ScalarType)Ushort) {
-    xa_nn_cat(
+  if ((out.scalar_type() == ScalarType::Int) ||
+      (out.scalar_type() == ScalarType::Short) ||
+      (out.scalar_type() == ScalarType::Char) ||
+      (out.scalar_type() == ScalarType::UInt32) ||
+      (out.scalar_type() == ScalarType::UInt16) ||
+      (out.scalar_type() == ScalarType::Byte)) {
+    XT_KERNEL_CHECK(
+        ctx,
+        out,
+        xa_nn_cat,
         out_data,
         out_shapes,
         inp_tensors,
         inp_tensors_shapes,
         inp_shapes_size[0],
         tensors.size(),
         (int)dim,
-        sizeof(short));
-  } else if (out.scalar_type() == ScalarType::Byte) {
-    xa_nn_cat(
-        out_data,
-        out_shapes,
-        inp_tensors,
-        inp_tensors_shapes,
-        inp_shapes_size[0],
-        tensors.size(),
-        (int)dim,
-        sizeof(char));
-
+        get_element_size(out.scalar_type()));
   } else {
-    // Special handling when all inputs are 1D-empty tensors for aten
-    // consistency In that case, just return an 1D-empty tensor without checking
-    // dim
-    bool all_1d_empty = true;
-    for (size_t i = 0; i < tensors.size(); ++i) {
-      if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
-        all_1d_empty = false;
-        break;
-      }
-    }
-    if (all_1d_empty) {
-      return out;
-    }
     const size_t outer = executorch::runtime::getLeadingDims(out, dim);
     const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
     const size_t ninputs = tensors.size();
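A rough sketch of what the consolidated branch above amounts to. The expansion of XT_KERNEL_CHECK is not shown here, so the error-handling shape below is an assumption modeled on ExecuTorch's ET_KERNEL_CHECK; get_element_size is presumed to come from the newly included xt_utils.h and to return the byte width that the removed sizeof() arguments used to supply:

    // Hypothetical expansion-style sketch (not the actual macro body):
    // run the NNLib kernel, then bail out of the op with `out` unchanged
    // if the kernel reports a non-zero status.
    int ret = xa_nn_cat(
        out_data,
        out_shapes,
        inp_tensors,
        inp_tensors_shapes,
        inp_shapes_size[0],
        tensors.size(),
        (int)dim,
        get_element_size(out.scalar_type())); // 4 for Int/UInt32, 2 for Short/UInt16, 1 for Char/Byte
    ET_KERNEL_CHECK(ctx, ret == 0, InvalidArgument, out);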