@@ -87,13 +87,13 @@ void check_args(
8787
8888 // Validate each output.
8989 for (size_t i = 0 ; i < out.size (); ++i) {
90- // All output dtypes must match the input type .
90+ // All output dtypes must be the same .
9191 ET_CHECK_MSG (
92- out[i].scalar_type () == input .scalar_type (),
93- " out[%zu] dtype %hhd != input dtype %hhd" ,
92+ out[i].scalar_type () == out[ 0 ] .scalar_type (),
93+ " out[%zu] dtype %hhd != out[0] dtype %hhd" ,
9494 i,
9595 out[i].scalar_type (),
96- input .scalar_type ());
96+ out[ 0 ] .scalar_type ());
9797
9898 // All outputs must have the same number of dimensions as the input.
9999 ET_CHECK_MSG (
@@ -170,26 +170,32 @@ void split_copy_Tensor_out(
170170
171171 const size_t leading_dims = getLeadingDims (input, dim);
172172 const size_t trailing_dims = getTrailingDims (input, dim);
173-
174- const size_t element_size = input.element_size ();
175- const size_t step = input.size (dim) * trailing_dims * element_size;
176-
177- const char * input_data = input.const_data_ptr <char >();
178- for (size_t i = 0 , e = out.size (); i < e; ++i) {
179- size_t num_bytes = out[i].size (dim) * trailing_dims * element_size;
180- if (num_bytes == 0 ) {
181- continue ;
182- }
183-
184- const char * src = input_data;
185- char * dest = out[i].mutable_data_ptr <char >();
186- for (size_t j = 0 ; j < leading_dims; ++j) {
187- memcpy (dest, src, num_bytes);
188- src += step;
189- dest += num_bytes;
190- }
191- input_data += num_bytes;
192- }
173+ const size_t step = input.size (dim) * trailing_dims;
174+
175+ ScalarType in_type = input.scalar_type ();
176+ ScalarType out_type = out[0 ].scalar_type ();
177+
178+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
179+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
180+ const CTYPE_IN* input_data = input.const_data_ptr <CTYPE_IN>();
181+ for (size_t i = 0 , e = out.size (); i < e; ++i) {
182+ size_t out_step = out[i].size (dim) * trailing_dims;
183+ if (out_step == 0 ) {
184+ continue ;
185+ }
186+ const CTYPE_IN* src = input_data;
187+ CTYPE_OUT* dest = out[i].mutable_data_ptr <CTYPE_OUT>();
188+ for (size_t j = 0 ; j < leading_dims; ++j) {
189+ for (size_t k = 0 ; k < out_step; ++k) {
190+ dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
191+ }
192+ src += step;
193+ dest += out_step;
194+ }
195+ input_data += out_step;
196+ }
197+ });
198+ });
193199}
194200
195201} // namespace native
0 commit comments