@@ -56,27 +56,58 @@ Tensor& cat_out(
   const size_t ninputs = tensors.size();
 
   const auto out_type = out.scalar_type();
-  ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
-    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
-    for (size_t i = 0; i < outer; ++i) {
-      for (size_t j = 0; j < ninputs; ++j) {
-        const auto in_type = tensors[j].scalar_type();
-        ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+  const bool out_is_complex =
+      executorch::runtime::isComplexType(out.scalar_type());
+
+  if (out_is_complex) {
+    // TODO: The current support for complex dtype enforces that input and
+    // output tensors have the same dtype. Support mixed dtypes in the future.
+    for (size_t i = 0; i < ninputs; ++i) {
+      const auto in_type = tensors[i].scalar_type();
+      ET_KERNEL_CHECK(ctx, out_type == in_type, InvalidArgument, out);
+    }
+    ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "cat.out", CTYPE, [&] {
+      CTYPE* out_ptr = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
           if (tensors[j].numel() == 0) {
             return;
           }
           size_t inner = tensors[j].size(dim) * dim_stride;
-          const CTYPE_IN* const in_ptr =
-              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
-
-          for (size_t k = 0; k < inner; ++k) {
-            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
-          }
+          const CTYPE* const in_ptr =
+              tensors[j].const_data_ptr<CTYPE>() + i * inner;
+          memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE));
           out_ptr += inner;
-        });
+        }
       }
-    }
-  });
+    });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] {
+      CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
+      for (size_t i = 0; i < outer; ++i) {
+        for (size_t j = 0; j < ninputs; ++j) {
+          const auto in_type = tensors[j].scalar_type();
+          ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] {
+            if (tensors[j].numel() == 0) {
+              return;
+            }
+            size_t inner = tensors[j].size(dim) * dim_stride;
+            const CTYPE_IN* const in_ptr =
+                tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;
+
+            if (sizeof(CTYPE_IN) == sizeof(CTYPE_OUT)) {
+              memcpy(out_ptr, in_ptr, inner * sizeof(CTYPE_IN));
+            } else {
+              for (size_t k = 0; k < inner; ++k) {
+                out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
+              }
+            }
+            out_ptr += inner;
+          });
+        }
+      }
+    });
+  }
 
   return out;
 }
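
For readers skimming the hunk: both branches flatten every input to a [outer, size(dim), dim_stride] view and append one contiguous inner block per input for each outer index, advancing out_ptr as they go. The standalone sketch below is a minimal illustration of that indexing, not ExecuTorch code; the Input struct, the cat_along_dim helper, and the shapes in main are assumptions made for the example, and the runtime plumbing (ctx, the ET_SWITCH_* macros, ET_KERNEL_CHECK) is deliberately elided. It uses std::copy where the kernel can memcpy, since the complex path requires matching input and output dtypes.

// Minimal, self-contained sketch (not ExecuTorch code) of the copy pattern
// used by cat.out: each input is viewed as [outer, size(dim), dim_stride]
// and contributes one contiguous block of size(dim) * dim_stride elements
// per outer index. Names below (Input, cat_along_dim) are illustrative only.
#include <algorithm>
#include <complex>
#include <iostream>
#include <vector>

using cfloat = std::complex<float>;

// One flattened input buffer plus its extent along the concatenation dim.
struct Input {
  std::vector<cfloat> data;
  size_t dim_size; // size(dim) for this input
};

std::vector<cfloat> cat_along_dim(
    const std::vector<Input>& inputs,
    size_t outer,        // product of sizes before dim
    size_t dim_stride) { // product of sizes after dim
  size_t out_dim_size = 0;
  for (const auto& in : inputs) {
    out_dim_size += in.dim_size;
  }
  std::vector<cfloat> out(outer * out_dim_size * dim_stride);

  cfloat* out_ptr = out.data();
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < inputs.size(); ++j) {
      // Same dtype on both sides, so the kernel can memcpy this block;
      // std::copy is used here to keep the sketch strictly portable.
      const size_t inner = inputs[j].dim_size * dim_stride;
      const cfloat* in_ptr = inputs[j].data.data() + i * inner;
      std::copy(in_ptr, in_ptr + inner, out_ptr);
      out_ptr += inner;
    }
  }
  return out;
}

int main() {
  // Concatenate a 2x2 and a 1x2 complex matrix along dim 0:
  // outer = 1 (no leading dims), dim_stride = 2 (one row of 2 elements).
  Input a{{{1, 1}, {2, 2}, {3, 3}, {4, 4}}, 2};
  Input b{{{5, 5}, {6, 6}}, 1};
  const auto out = cat_along_dim({a, b}, /*outer=*/1, /*dim_stride=*/2);
  for (const auto& v : out) {
    std::cout << v << ' '; // (1,1) (2,2) (3,3) (4,4) (5,5) (6,6)
  }
  std::cout << '\n';
  return 0;
}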