99#include < cstring>
1010
1111#include < executorch/kernels/portable/cpu/util/broadcast_util.h>
12+ #include < executorch/kernels/portable/cpu/util/elementwise_util.h>
1213#include < executorch/runtime/kernel/kernel_includes.h>
1314
1415namespace torch {
@@ -42,19 +43,18 @@ Tensor& copy_out(
4243 ET_KERNEL_CHECK (
4344 ctx, tensors_have_same_dim_order (in, out), InvalidArgument, out);
4445
45- ScalarType in_type = in.scalar_type ();
46- ScalarType src_type = src.scalar_type ();
47-
48- ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " copy.out" , CTYPE, [&]() {
49- ET_SWITCH_REALHBBF16_TYPES (src_type, ctx, " copy.out" , CTYPE_SRC, [&]() {
50- apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
51- [](const CTYPE val_in, const CTYPE_SRC val_src) {
52- return convert<CTYPE, CTYPE_SRC>(val_src);
53- },
54- in,
55- src,
56- out);
57- });
46+ static constexpr const char op_name[] = " copy.out" ;
47+
48+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy.out" , CTYPE, [&]() {
49+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
50+ [](const CTYPE _, const CTYPE val_src) { return val_src; },
51+ ctx,
52+ in,
53+ utils::SupportedTensorDtypes::REALHBBF16,
54+ src,
55+ utils::SupportedTensorDtypes::REALHBBF16,
56+ out,
57+ utils::SupportedTensorDtypes::REALHBBF16);
5858 });
5959
6060 return out;
@@ -75,19 +75,18 @@ Tensor& copy_(
7575 ET_KERNEL_CHECK (
7676 ctx, tensors_have_same_dim_order (in, src), InvalidArgument, in);
7777
78- ScalarType in_type = in.scalar_type ();
79- ScalarType src_type = src.scalar_type ();
80-
81- ET_SWITCH_REALHBBF16_TYPES (in_type, ctx, " copy_" , CTYPE, [&]() {
82- ET_SWITCH_REALHBBF16_TYPES (src_type, ctx, " copy_" , CTYPE_SRC, [&]() {
83- apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
84- [](const CTYPE val_in, const CTYPE_SRC val_src) {
85- return convert<CTYPE, CTYPE_SRC>(val_src);
86- },
87- in,
88- src,
89- in);
90- });
78+ static constexpr const char op_name[] = " copy_" ;
79+
80+ ET_SWITCH_REALHBBF16_TYPES (in.scalar_type (), ctx, " copy_" , CTYPE, [&]() {
81+ utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
82+ [](const CTYPE _, const CTYPE val_src) { return val_src; },
83+ ctx,
84+ in,
85+ utils::SupportedTensorDtypes::REALHBBF16,
86+ src,
87+ utils::SupportedTensorDtypes::REALHBBF16,
88+ in,
89+ utils::SupportedTensorDtypes::REALHBBF16);
9190 });
9291
9392 return in;
0 commit comments