@@ -46,17 +46,24 @@ Tensor& copy_out(
4646 // @lint-ignore CLANGTIDY facebook-hte-CArray
4747 static constexpr const char op_name[] = " copy.out" ;
4848
49- ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy.out" , CTYPE, [&]() {
50- utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51- [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52- ctx,
53- in,
54- utils::SupportedTensorDtypes::REALHBBF16,
55- src,
56- utils::SupportedTensorDtypes::REALHBBF16,
57- out,
58- utils::SupportedTensorDtypes::REALHBBF16);
59- });
49+ // Use direct copy fast path if broadcast is not needed and tensors are
50+ // non-empty
51+ if (internal::sizes_match_ignoring_leading_1s (out.sizes (), src.sizes ()) &&
52+ src.numel () > 0 ) {
53+ std::memcpy (out.mutable_data_ptr (), src.const_data_ptr (), src.nbytes ());
54+ } else {
55+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy.out" , CTYPE, [&]() {
56+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
57+ [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
58+ ctx,
59+ in,
60+ utils::SupportedTensorDtypes::REALHBBF16,
61+ src,
62+ utils::SupportedTensorDtypes::REALHBBF16,
63+ out,
64+ utils::SupportedTensorDtypes::REALHBBF16);
65+ });
66+ }
6067
6168 return out;
6269}
@@ -79,17 +86,24 @@ Tensor& copy_(
7986 // @lint-ignore CLANGTIDY facebook-hte-CArray
8087 static constexpr const char op_name[] = " copy_" ;
8188
82- ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy_" , CTYPE, [&]() {
83- utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84- [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85- ctx,
86- in,
87- utils::SupportedTensorDtypes::REALHBBF16,
88- src,
89- utils::SupportedTensorDtypes::REALHBBF16,
90- in,
91- utils::SupportedTensorDtypes::REALHBBF16);
92- });
89+ // Use direct copy fast path if broadcast is not needed and tensors are
90+ // non-empty
91+ if (internal::sizes_match_ignoring_leading_1s (in.sizes (), src.sizes ()) &&
92+ src.numel () > 0 ) {
93+ std::memcpy (in.mutable_data_ptr (), src.const_data_ptr (), in.nbytes ());
94+ } else {
95+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy_" , CTYPE, [&]() {
96+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
97+ [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
98+ ctx,
99+ in,
100+ utils::SupportedTensorDtypes::REALHBBF16,
101+ src,
102+ utils::SupportedTensorDtypes::REALHBBF16,
103+ in,
104+ utils::SupportedTensorDtypes::REALHBBF16);
105+ });
106+ }
93107
94108 return in;
95109}
0 commit comments