@@ -41,13 +41,13 @@ void check_args(const Tensor& input, int64_t dim, TensorList out) {
4141
4242 // Validate each output.
4343 for (size_t i = 0 ; i < out.size (); ++i) {
44- // All output dtypes must match the input type .
44+ // All output dtypes must be the same .
4545 ET_CHECK_MSG (
46- out[i].scalar_type () == input .scalar_type (),
47- " out[%zu] dtype %hhd != input dtype %hhd" ,
46+ out[i].scalar_type () == out[ 0 ] .scalar_type (),
47+ " out[%zu] dtype %hhd != out[0] dtype %hhd" ,
4848 i,
4949 out[i].scalar_type (),
50- input .scalar_type ());
50+ out[ 0 ] .scalar_type ());
5151
5252 // output tensor must have # of dims = input.dim() -1
5353 ET_CHECK_MSG (
@@ -97,25 +97,29 @@ void unbind_copy_int_out(
9797
9898 const size_t leading_dims = getLeadingDims (input, dim);
9999 const size_t trailing_dims = getTrailingDims (input, dim);
100-
101- const size_t element_size = input.element_size ();
102- const size_t step = input.size (dim) * trailing_dims * element_size;
103-
104- const char * input_data = input.const_data_ptr <char >();
105- for (size_t i = 0 , e = out.size (); i < e; ++i) {
106- size_t num_bytes = trailing_dims * element_size;
107- // num_bytes should not be zero because trailing_dims
108- // will at least return 1
109-
110- const char * src = input_data;
111- char * dest = out[i].mutable_data_ptr <char >();
112- for (size_t j = 0 ; j < leading_dims; ++j) {
113- memcpy (dest, src, num_bytes);
114- src += step;
115- dest += num_bytes;
116- }
117- input_data += num_bytes;
118- }
100+ const size_t step = input.size (dim) * trailing_dims;
101+
102+ ScalarType in_type = input.scalar_type ();
103+ ScalarType out_type = out[0 ].scalar_type ();
104+
105+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
106+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
107+ const CTYPE_IN* const input_data = input.const_data_ptr <CTYPE_IN>();
108+ for (size_t i = 0 , e = out.size (); i < e; ++i) {
109+ size_t input_offset = i * trailing_dims;
110+ CTYPE_OUT* const dest = out[i].mutable_data_ptr <CTYPE_OUT>();
111+ size_t dest_offset = 0 ;
112+ for (size_t j = 0 ; j < leading_dims; ++j) {
113+ for (size_t k = 0 ; k < trailing_dims; ++k) {
114+ dest[dest_offset + k] =
115+ convert<CTYPE_OUT, CTYPE_IN>(input_data[input_offset + k]);
116+ }
117+ input_offset += step;
118+ dest_offset += trailing_dims;
119+ }
120+ }
121+ });
122+ });
119123}
120124
121125} // namespace native
0 commit comments