@@ -56,27 +56,57 @@ Tensor& cat_out(
5656 const size_t ninputs = tensors.size ();
5757
5858 const auto out_type = out.scalar_type ();
59- ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, " cat.out" , CTYPE_OUT, [&] {
60- CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
61- for (size_t i = 0 ; i < outer; ++i) {
62- for (size_t j = 0 ; j < ninputs; ++j) {
63- const auto in_type = tensors[j].scalar_type ();
64- ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " cat.out" , CTYPE_IN, [&] {
59+ const bool out_is_complex =
60+ executorch::runtime::isComplexType (out.scalar_type ());
61+
62+ if (out_is_complex) {
63+ // All the input tensors and output must have same dtype
64+ for (size_t i = 0 ; i < ninputs; ++i) {
65+ const auto in_type = tensors[i].scalar_type ();
66+ ET_KERNEL_CHECK (ctx, out_type == in_type, InvalidArgument, out);
67+ }
68+ ET_SWITCH_COMPLEXH_TYPES (out_type, ctx, " cat.out" , CTYPE, [&] {
69+ CTYPE* out_ptr = out.mutable_data_ptr <CTYPE>();
70+ for (size_t i = 0 ; i < outer; ++i) {
71+ for (size_t j = 0 ; j < ninputs; ++j) {
6572 if (tensors[j].numel () == 0 ) {
6673 return ;
6774 }
6875 size_t inner = tensors[j].size (dim) * dim_stride;
69- const CTYPE_IN* const in_ptr =
70- tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
71-
72- for (size_t k = 0 ; k < inner; ++k) {
73- out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
74- }
76+ const CTYPE* const in_ptr =
77+ tensors[j].const_data_ptr <CTYPE>() + i * inner;
78+ memcpy (out_ptr, in_ptr, inner * sizeof (CTYPE));
7579 out_ptr += inner;
76- });
80+ }
7781 }
78- }
79- });
82+ });
83+ } else {
84+ ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, " cat.out" , CTYPE_OUT, [&] {
85+ CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
86+ for (size_t i = 0 ; i < outer; ++i) {
87+ for (size_t j = 0 ; j < ninputs; ++j) {
88+ const auto in_type = tensors[j].scalar_type ();
89+ ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " cat.out" , CTYPE_IN, [&] {
90+ if (tensors[j].numel () == 0 ) {
91+ return ;
92+ }
93+ size_t inner = tensors[j].size (dim) * dim_stride;
94+ const CTYPE_IN* const in_ptr =
95+ tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
96+
97+ if (sizeof (CTYPE_IN) == sizeof (CTYPE_OUT)) {
98+ memcpy (out_ptr, in_ptr, inner * sizeof (CTYPE_IN));
99+ } else {
100+ for (size_t k = 0 ; k < inner; ++k) {
101+ out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
102+ }
103+ }
104+ out_ptr += inner;
105+ });
106+ }
107+ }
108+ });
109+ }
80110
81111 return out;
82112}
0 commit comments