@@ -87,13 +87,13 @@ void check_args(
87
87
88
88
// Validate each output.
89
89
for (size_t i = 0 ; i < out.size (); ++i) {
90
- // All output dtypes must match the input type .
90
+ // All output dtypes must be the same .
91
91
ET_CHECK_MSG (
92
- out[i].scalar_type () == input .scalar_type (),
93
- " out[%zu] dtype %hhd != input dtype %hhd" ,
92
+ out[i].scalar_type () == out[ 0 ] .scalar_type (),
93
+ " out[%zu] dtype %hhd != out[0] dtype %hhd" ,
94
94
i,
95
95
out[i].scalar_type (),
96
- input .scalar_type ());
96
+ out[ 0 ] .scalar_type ());
97
97
98
98
// All outputs must have the same number of dimensions as the input.
99
99
ET_CHECK_MSG (
@@ -170,26 +170,32 @@ void split_copy_Tensor_out(
170
170
171
171
const size_t leading_dims = getLeadingDims (input, dim);
172
172
const size_t trailing_dims = getTrailingDims (input, dim);
173
-
174
- const size_t element_size = input.element_size ();
175
- const size_t step = input.size (dim) * trailing_dims * element_size;
176
-
177
- const char * input_data = input.const_data_ptr <char >();
178
- for (size_t i = 0 , e = out.size (); i < e; ++i) {
179
- size_t num_bytes = out[i].size (dim) * trailing_dims * element_size;
180
- if (num_bytes == 0 ) {
181
- continue ;
182
- }
183
-
184
- const char * src = input_data;
185
- char * dest = out[i].mutable_data_ptr <char >();
186
- for (size_t j = 0 ; j < leading_dims; ++j) {
187
- memcpy (dest, src, num_bytes);
188
- src += step;
189
- dest += num_bytes;
190
- }
191
- input_data += num_bytes;
192
- }
173
+ const size_t step = input.size (dim) * trailing_dims;
174
+
175
+ ScalarType in_type = input.scalar_type ();
176
+ ScalarType out_type = out[0 ].scalar_type ();
177
+
178
+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
179
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
180
+ const CTYPE_IN* input_data = input.const_data_ptr <CTYPE_IN>();
181
+ for (size_t i = 0 , e = out.size (); i < e; ++i) {
182
+ size_t out_step = out[i].size (dim) * trailing_dims;
183
+ if (out_step == 0 ) {
184
+ continue ;
185
+ }
186
+ const CTYPE_IN* src = input_data;
187
+ CTYPE_OUT* dest = out[i].mutable_data_ptr <CTYPE_OUT>();
188
+ for (size_t j = 0 ; j < leading_dims; ++j) {
189
+ for (size_t k = 0 ; k < out_step; ++k) {
190
+ dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
191
+ }
192
+ src += step;
193
+ dest += out_step;
194
+ }
195
+ input_data += out_step;
196
+ }
197
+ });
198
+ });
193
199
}
194
200
195
201
} // namespace native
0 commit comments