99#include < cstring>
1010
1111#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
12+ #include < executorch/kernels/portable/cpu/util/elementwise_util.h>
1213#include < executorch/runtime/kernel/kernel_includes.h>
1314
1415namespace torch {
@@ -42,19 +43,19 @@ Tensor& copy_out(
4243 ET_KERNEL_CHECK (
4344 ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
4445
45- ScalarType in_type = in. scalar_type ();
46- ScalarType src_type = src. scalar_type () ;
47-
48- ET_SWITCH_REALHBBF16_TYPES (in_type , ctx, " copy.out" , CTYPE, [&]() {
49- ET_SWITCH_REALHBBF16_TYPES (src_type, ctx, " copy.out " , CTYPE_SRC, [&]() {
50- apply_binary_elementwise_fn< CTYPE, CTYPE_SRC, CTYPE>(
51- []( const CTYPE val_in, const CTYPE_SRC val_src) {
52- return convert<CTYPE, CTYPE_SRC>(val_src);
53- } ,
54- in ,
55- src ,
56- out);
57- } );
46+ // @lint-ignore CLANGTIDY facebook-hte-CArray
47+ static constexpr const char op_name[] = " copy.out " ;
48+
49+ ET_SWITCH_REALHBBF16_TYPES (in. scalar_type () , ctx, " copy.out" , CTYPE, [&]() {
50+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51+ []( const CTYPE _, const CTYPE val_src) { return val_src; },
52+ ctx,
53+ in,
54+ utils::SupportedTensorDtypes::REALHBBF16 ,
55+ src ,
56+ utils::SupportedTensorDtypes::REALHBBF16 ,
57+ out,
58+ utils::SupportedTensorDtypes::REALHBBF16 );
5859 });
5960
6061 return out;
@@ -75,19 +76,18 @@ Tensor& copy_(
7576 ET_KERNEL_CHECK (
7677 ctx, tensors_have_same_dim_order (in, src), InvalidArgument, in);
7778
78- ScalarType in_type = in.scalar_type ();
79- ScalarType src_type = src.scalar_type ();
80-
81- ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " copy_" , CTYPE, [&]() {
82- ET_SWITCH_REALHBBF16_TYPES (src_type, ctx, " copy_" , CTYPE_SRC, [&]() {
83- apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
84- [](const CTYPE val_in, const CTYPE_SRC val_src) {
85- return convert<CTYPE, CTYPE_SRC>(val_src);
86- },
87- in,
88- src,
89- in);
90- });
79+ static constexpr const char op_name[] = " copy_" ;
80+
81+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy_" , CTYPE, [&]() {
82+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
83+ [](const CTYPE _, const CTYPE val_src) { return val_src; },
84+ ctx,
85+ in,
86+ utils::SupportedTensorDtypes::REALHBBF16,
87+ src,
88+ utils::SupportedTensorDtypes::REALHBBF16,
89+ in,
90+ utils::SupportedTensorDtypes::REALHBBF16);
9191 });
9292
9393 return in;
0 commit comments