@@ -14,10 +14,11 @@ namespace executor {
1414namespace native {
1515namespace {
1616
17- template <typename CTYPE>
1817void pixel_shuffle_impl (const Tensor& in, int64_t upscale_factor, Tensor& out) {
19- const CTYPE* const in_data = in.const_data_ptr <CTYPE>();
20- CTYPE* const out_data = out.mutable_data_ptr <CTYPE>();
18+ const char * const in_data =
19+ reinterpret_cast <const char *>(in.const_data_ptr ());
20+ char * const out_data = reinterpret_cast <char *>(out.mutable_data_ptr ());
21+ const auto elem_size = in.element_size ();
2122
2223 const auto leading_dims = getLeadingDims (in, in.dim () - 3 );
2324 const auto channels = in.size (in.dim () - 3 );
@@ -45,7 +46,11 @@ void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) {
4546 for (size_t s2 = 0 ; s2 < S; s2++) {
4647 size_t input_offset = n * stride_n + c * stride_c +
4748 s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w;
48- out_data[i++] = in_data[input_offset];
49+ std::memcpy (
50+ out_data + i * elem_size,
51+ in_data + input_offset * elem_size,
52+ elem_size);
53+ i++;
4954 }
5055 }
5156 }
@@ -88,13 +93,7 @@ Tensor& pixel_shuffle_out(
8893 InvalidArgument,
8994 out);
9095
91- constexpr auto name = " pixel_shuffle.out" ;
92-
93- const auto in_type = out.scalar_type ();
94- // in and out must be the same dtype
95- ET_SWITCH_ALL_TYPES (in_type, ctx, name, CTYPE, [&]() {
96- pixel_shuffle_impl<CTYPE>(in, upscale_factor, out);
97- });
96+ pixel_shuffle_impl (in, upscale_factor, out);
9897
9998 return out;
10099}
0 commit comments