@@ -41,13 +41,13 @@ void check_args(const Tensor& input, int64_t dim, TensorList out) {
41
41
42
42
// Validate each output.
43
43
for (size_t i = 0 ; i < out.size (); ++i) {
44
- // All output dtypes must match the input type .
44
+ // All output dtypes must match the dtype of out[0].
45
45
ET_CHECK_MSG (
46
- out[i].scalar_type () == input .scalar_type (),
47
- " out[%zu] dtype %hhd != input dtype %hhd" ,
46
+ out[i].scalar_type () == out[ 0 ] .scalar_type (),
47
+ " out[%zu] dtype %hhd != out[0] dtype %hhd" ,
48
48
i,
49
49
out[i].scalar_type (),
50
- input .scalar_type ());
50
+ out[ 0 ] .scalar_type ());
51
51
52
52
// Each output tensor must have input.dim() - 1 dimensions.
53
53
ET_CHECK_MSG (
@@ -97,25 +97,29 @@ void unbind_copy_int_out(
97
97
98
98
const size_t leading_dims = getLeadingDims (input, dim);
99
99
const size_t trailing_dims = getTrailingDims (input, dim);
100
-
101
- const size_t element_size = input.element_size ();
102
- const size_t step = input.size (dim) * trailing_dims * element_size;
103
-
104
- const char * input_data = input.const_data_ptr <char >();
105
- for (size_t i = 0 , e = out.size (); i < e; ++i) {
106
- size_t num_bytes = trailing_dims * element_size;
107
- // num_bytes should not be zero because trailing_dims
108
- // will at least return 1
109
-
110
- const char * src = input_data;
111
- char * dest = out[i].mutable_data_ptr <char >();
112
- for (size_t j = 0 ; j < leading_dims; ++j) {
113
- memcpy (dest, src, num_bytes);
114
- src += step;
115
- dest += num_bytes;
116
- }
117
- input_data += num_bytes;
118
- }
100
+ const size_t step = input.size (dim) * trailing_dims;
101
+
102
+ ScalarType in_type = input.scalar_type ();
103
+ ScalarType out_type = out[0 ].scalar_type ();
104
+
105
+ ET_SWITCH_REAL_TYPES_AND (Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
106
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
107
+ const CTYPE_IN* const input_data = input.const_data_ptr <CTYPE_IN>();
108
+ for (size_t i = 0 , e = out.size (); i < e; ++i) {
109
+ size_t input_offset = i * trailing_dims;
110
+ CTYPE_OUT* const dest = out[i].mutable_data_ptr <CTYPE_OUT>();
111
+ size_t dest_offset = 0 ;
112
+ for (size_t j = 0 ; j < leading_dims; ++j) {
113
+ for (size_t k = 0 ; k < trailing_dims; ++k) {
114
+ dest[dest_offset + k] =
115
+ convert<CTYPE_OUT, CTYPE_IN>(input_data[input_offset + k]);
116
+ }
117
+ input_offset += step;
118
+ dest_offset += trailing_dims;
119
+ }
120
+ }
121
+ });
122
+ });
119
123
}
120
124
121
125
} // namespace native
0 commit comments