@@ -18,9 +18,20 @@ using exec_aten::Tensor;
1818using torch::executor::KernelRuntimeContext;
1919using torch::executor::Error;
2020
21+ /* ScalarType in Executorch do not have support for below data types.
22+ * So, creating a placeholder for these data types. Once, ScalarTypes is
23+ * updated to have support for below data types, these can be removed and
24+ * operator need to be updated accordingly
25+ */
26+ enum datatype {
27+ Ushort = 20 ,
28+ Uint = 23 ,
29+ };
30+
2131
32+ namespace cadence {
2233namespace impl {
23- namespace FusionG3 {
34+ namespace G3 {
2435namespace native {
2536
2637
@@ -95,6 +106,22 @@ Tensor& cat_out(KernelRuntimeContext& ctx,
95106 inp_shapes_size[0 ], tensors.size (), (int )dim, sizeof (char ));
96107
97108 }
109+ if (out.scalar_type () == (ScalarType)Uint)
110+ {
111+ xa_nn_cat (out_data, out_shapes, inp_tensors, inp_tensors_shapes,
112+ inp_shapes_size[0 ], tensors.size (), (int )dim, sizeof (int ));
113+ }
114+ else if (out.scalar_type () == (ScalarType)Ushort)
115+ {
116+ xa_nn_cat (out_data, out_shapes, inp_tensors, inp_tensors_shapes,
117+ inp_shapes_size[0 ], tensors.size (), (int )dim, sizeof (short ));
118+ }
119+ else if (out.scalar_type () == ScalarType::Byte)
120+ {
121+ xa_nn_cat (out_data, out_shapes, inp_tensors, inp_tensors_shapes,
122+ inp_shapes_size[0 ], tensors.size (), (int )dim, sizeof (char ));
123+
124+ }
98125 else
99126 {
100127 // Special handling when all inputs are 1D-empty tensors for aten consistency
@@ -145,5 +172,6 @@ Tensor& cat_out(KernelRuntimeContext& ctx,
145172}
146173
147174} // namespace native
148- } // namespace FusionG3
149- } // namespace impl
175+ } // namespace G3
176+ } // namespace impl
177+ } // namespace cadence
0 commit comments