diff --git a/kernels/portable/cpu/op_pixel_shuffle.cpp b/kernels/portable/cpu/op_pixel_shuffle.cpp index a191048386b..0af790d985e 100644 --- a/kernels/portable/cpu/op_pixel_shuffle.cpp +++ b/kernels/portable/cpu/op_pixel_shuffle.cpp @@ -14,10 +14,11 @@ namespace executor { namespace native { namespace { -template void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) { - const CTYPE* const in_data = in.const_data_ptr(); - CTYPE* const out_data = out.mutable_data_ptr(); + const char* const in_data = + reinterpret_cast(in.const_data_ptr()); + char* const out_data = reinterpret_cast(out.mutable_data_ptr()); + const auto elem_size = in.element_size(); const auto leading_dims = getLeadingDims(in, in.dim() - 3); const auto channels = in.size(in.dim() - 3); @@ -45,7 +46,11 @@ void pixel_shuffle_impl(const Tensor& in, int64_t upscale_factor, Tensor& out) { for (size_t s2 = 0; s2 < S; s2++) { size_t input_offset = n * stride_n + c * stride_c + s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w; - out_data[i++] = in_data[input_offset]; + std::memcpy( + out_data + i * elem_size, + in_data + input_offset * elem_size, + elem_size); + i++; } } } @@ -88,13 +93,7 @@ Tensor& pixel_shuffle_out( InvalidArgument, out); - constexpr auto name = "pixel_shuffle.out"; - - const auto in_type = out.scalar_type(); - // in and out must be the same dtype - ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() { - pixel_shuffle_impl(in, upscale_factor, out); - }); + pixel_shuffle_impl(in, upscale_factor, out); return out; } diff --git a/kernels/portable/cpu/op_pixel_unshuffle.cpp b/kernels/portable/cpu/op_pixel_unshuffle.cpp index 0cbc9756d92..f12a2e97aed 100644 --- a/kernels/portable/cpu/op_pixel_unshuffle.cpp +++ b/kernels/portable/cpu/op_pixel_unshuffle.cpp @@ -14,13 +14,14 @@ namespace executor { namespace native { namespace { -template void pixel_unshuffle_impl( const Tensor& in, int64_t downscale_factor, Tensor& out) { - const CTYPE* const in_data = in.const_data_ptr(); - CTYPE* const out_data = out.mutable_data_ptr(); + const char* const in_data = + reinterpret_cast(in.const_data_ptr()); + char* const out_data = reinterpret_cast(out.mutable_data_ptr()); + const auto elem_size = in.element_size(); const auto leading_dims = getLeadingDims(in, in.dim() - 3); const auto channels = out.size(in.dim() - 3); @@ -48,7 +49,11 @@ void pixel_unshuffle_impl( for (size_t s2 = 0; s2 < S; s2++) { size_t output_offset = n * stride_n + c * stride_c + s1 * stride_s1 + s2 * stride_s2 + h * stride_h + w; - out_data[output_offset] = in_data[i++]; + std::memcpy( + out_data + output_offset * elem_size, + in_data + i * elem_size, + elem_size); + i++; } } } @@ -88,13 +93,7 @@ Tensor& pixel_unshuffle_out( InvalidArgument, out); - constexpr auto name = "pixel_unshuffle.out"; - - const auto in_type = out.scalar_type(); - // in and out must be the same dtype - ET_SWITCH_ALL_TYPES(in_type, ctx, name, CTYPE, [&]() { - pixel_unshuffle_impl(in, downscale_factor, out); - }); + pixel_unshuffle_impl(in, downscale_factor, out); return out; }