@@ -56,27 +56,58 @@ Tensor& cat_out(
   const size_t ninputs = tensors.size();
 
   const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+  const bool out_is_complex =
+      executorch::runtime::isComplexType(out.scalar_type());
+
+  if (out_is_complex) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
+    for (size_t i = 0; i < ninputs; ++i) {
+      const auto in_type = tensors[i].scalar_type();
+      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
+    }
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
           if (tensors[j].numel() == 0) {
             return;
           }
           size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
+          const CTYPE* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE>() + i * inner;
+          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
           out_ptr += inner;
-        });
+        }
       }
-    }
-  });
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          const auto in_type = tensors[j].scalar_type();
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+            if (tensors[j].numel() == 0) {
+              return;
+            }
+            size_t inner = tensors[j].size(dim) * dim_stride;
+            const CTYPE_IN* const in_ptr =
+                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+            if (std::is_same<CTYPE_IN, CTYPE_OUT>::value) {
+              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
+            } else {
+              for (size_t k = 0; k < inner; ++k) {
+                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+              }
+            }
+            out_ptr += inner;
+          });
+        }
+      }
+    });
+  }
 
   return out;
 }
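
For readers of the new complex branch: the kernel views `cat` along `dim` as `outer` passes (the product of sizes before `dim`), each copying one contiguous segment of `inner = size(dim) * dim_stride` elements per input, where `dim_stride` is the product of sizes after `dim`. Since the complex path requires every input to share the output dtype, each segment reduces to a single `memcpy`. Below is a minimal standalone sketch of that copy pattern, assuming contiguous row-major data; the tensors, sizes, and harness are invented for illustration and are not ExecuTorch APIs:

```cpp
#include <complex>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  using CTYPE = std::complex<float>;
  // Two 2x2 inputs concatenated along dim = 0 into a 4x2 output.
  std::vector<CTYPE> a = {{1, 1}, {2, 2}, {3, 3}, {4, 4}};
  std::vector<CTYPE> b = {{5, 5}, {6, 6}, {7, 7}, {8, 8}};
  const std::vector<const std::vector<CTYPE>*> tensors = {&a, &b};

  const std::size_t outer = 1;       // product of sizes before dim (none)
  const std::size_t dim_stride = 2;  // product of sizes after dim (columns)
  const std::size_t dim_size = 2;    // each input's extent along dim

  std::vector<CTYPE> out(outer * tensors.size() * dim_size * dim_stride);
  CTYPE* out_ptr = out.data();
  for (std::size_t i = 0; i < outer; ++i) {
    for (std::size_t j = 0; j < tensors.size(); ++j) {
      const std::size_t inner = dim_size * dim_stride;
      const CTYPE* in_ptr = tensors[j]->data() + i * inner;
      // All dtypes match, so one raw byte copy moves the whole segment.
      std::memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
      out_ptr += inner;
    }
  }
  for (const CTYPE& c : out) {
    std::printf("(%g,%g) ", c.real(), c.imag());  // (1,1) (2,2) ... (8,8)
  }
  std::printf("\n");
}
```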
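
The real-dtype branch still permits mixed input/output dtypes, which is why the per-element `static_cast` loop is kept as a fallback: a raw `memcpy` is only a valid shortcut when the element types are actually identical, since equal-sized but distinct types (say, `int32_t` and `float`) would have their bits reinterpreted rather than their values converted. A hypothetical standalone sketch of that cast-vs-copy dispatch (function and names invented for illustration):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>

// Copy n elements, converting values only when the types differ.
template <typename CTYPE_IN, typename CTYPE_OUT>
void copy_segment(const CTYPE_IN* in, CTYPE_OUT* out, std::size_t n) {
  if (std::is_same<CTYPE_IN, CTYPE_OUT>::value) {
    // Identical element types: a bit-exact copy, no per-element cast.
    std::memcpy(out, in, n * sizeof(CTYPE_IN));
  } else {
    // Mixed dtypes (e.g. int32_t -> float): each value must be
    // converted; a memcpy here would reinterpret bits instead.
    for (std::size_t k = 0; k < n; ++k) {
      out[k] = static_cast<CTYPE_OUT>(in[k]);
    }
  }
}

int main() {
  const int32_t in[3] = {1, 2, 3};

  float as_float[3];
  copy_segment(in, as_float, 3);  // takes the cast path
  std::printf("%g %g %g\n", as_float[0], as_float[1], as_float[2]);

  int32_t as_int[3];
  copy_segment(in, as_int, 3);    // takes the memcpy path
  std::printf("%d %d %d\n", (int)as_int[0], (int)as_int[1], (int)as_int[2]);
}
```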