@@ -56,27 +56,60 @@ Tensor& cat_out(
5656 const size_t ninputs = tensors.size ();
5757
5858 const auto out_type = out.scalar_type ();
59- ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, " cat.out" , CTYPE_OUT, [&] {
60- CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
61- for (size_t i = 0 ; i < outer; ++i) {
62- for (size_t j = 0 ; j < ninputs; ++j) {
63- const auto in_type = tensors[j].scalar_type ();
64- ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " cat.out" , CTYPE_IN, [&] {
65- if (tensors[j].numel () == 0 ) {
66- return ;
67- }
68- size_t inner = tensors[j].size (dim) * dim_stride;
69- const CTYPE_IN* const in_ptr =
70- tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
59+ const bool out_is_complex =
60+ executorch::runtime::isComplexType (out.scalar_type ());
7161
72- for (size_t k = 0 ; k < inner; ++k) {
73- out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
74- }
75- out_ptr += inner;
76- });
77- }
62+ if (out_is_complex) {
63+ // All the input tensors and output must have same dtype
64+ for (size_t i = 0 ; i < ninputs; ++i) {
65+ const auto in_type = tensors[i].scalar_type ();
66+ ET_KERNEL_CHECK (ctx, out_type == in_type, InvalidArgument, out);
7867 }
79- });
68+ ET_SWITCH_COMPLEXH_TYPES (out_type, ctx, " cat.out" , CTYPE_OUT, [&] {
69+ CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
70+ for (size_t i = 0 ; i < outer; ++i) {
71+ for (size_t j = 0 ; j < ninputs; ++j) {
72+ const auto in_type = tensors[j].scalar_type ();
73+ ET_SWITCH_COMPLEXH_TYPES (in_type, ctx, " cat.out" , CTYPE_IN, [&] {
74+ if (tensors[j].numel () == 0 ) {
75+ return ;
76+ }
77+ size_t inner = tensors[j].size (dim) * dim_stride;
78+ const CTYPE_IN* const in_ptr =
79+ tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
80+ memcpy (out_ptr, in_ptr, inner * sizeof (CTYPE_IN));
81+ out_ptr += inner;
82+ });
83+ }
84+ }
85+ });
86+ } else {
87+ ET_SWITCH_REALHBBF16_TYPES (out_type, ctx, " cat.out" , CTYPE_OUT, [&] {
88+ CTYPE_OUT* out_ptr = out.mutable_data_ptr <CTYPE_OUT>();
89+ for (size_t i = 0 ; i < outer; ++i) {
90+ for (size_t j = 0 ; j < ninputs; ++j) {
91+ const auto in_type = tensors[j].scalar_type ();
92+ ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " cat.out" , CTYPE_IN, [&] {
93+ if (tensors[j].numel () == 0 ) {
94+ return ;
95+ }
96+ size_t inner = tensors[j].size (dim) * dim_stride;
97+ const CTYPE_IN* const in_ptr =
98+ tensors[j].const_data_ptr <CTYPE_IN>() + i * inner;
99+
100+ if (sizeof (CTYPE_IN) == sizeof (CTYPE_OUT)) {
101+ memcpy (out_ptr, in_ptr, inner * sizeof (CTYPE_IN));
102+ } else {
103+ for (size_t k = 0 ; k < inner; ++k) {
104+ out_ptr[k] = static_cast <CTYPE_OUT>(in_ptr[k]);
105+ }
106+ }
107+ out_ptr += inner;
108+ });
109+ }
110+ }
111+ });
112+ }
80113
81114 return out;
82115}
0 commit comments