diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp
index 19b0c3a2f6a..069800cc399 100644
--- a/kernels/portable/cpu/op_copy.cpp
+++ b/kernels/portable/cpu/op_copy.cpp
@@ -46,17 +46,25 @@ Tensor& copy_out(
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "copy.out";
 
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn(
-        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
-        ctx,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        src,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
-  });
+  // Fast path: use a raw byte copy when no broadcast is needed, src and out
+  // share a dtype (memcpy performs no conversion), and src is non-empty.
+  // Everything else goes through the elementwise path.
+  if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) &&
+      src.scalar_type() == out.scalar_type() && src.numel() > 0) {
+    std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
+      utils::apply_bitensor_elementwise_fn(
+          [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
+          ctx,
+          in,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          src,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out,
+          utils::SupportedTensorDtypes::REALHBBF16);
+    });
+  }
 
   return out;
 }
@@ -79,17 +87,25 @@ Tensor& copy_(
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "copy_";
 
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn(
-        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
-        ctx,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        src,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16);
-  });
+  // Fast path: use a raw byte copy when no broadcast is needed, src and in
+  // share a dtype (memcpy performs no conversion), and src is non-empty.
+  // Everything else goes through the elementwise path.
+  if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) &&
+      src.scalar_type() == in.scalar_type() && src.numel() > 0) {
+    std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
+      utils::apply_bitensor_elementwise_fn(
+          [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
+          ctx,
+          in,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          src,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          in,
+          utils::SupportedTensorDtypes::REALHBBF16);
+    });
+  }
 
   return in;
 }