diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 013027d30ba..cb9e9de2dad 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace torch { @@ -42,19 +43,19 @@ Tensor& copy_out( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ScalarType in_type = in.scalar_type(); - ScalarType src_type = src.scalar_type(); - - ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { - ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE val_in, const CTYPE_SRC val_src) { - return convert(val_src); - }, - in, - src, - out); - }); + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "copy.out"; + + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + src, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; @@ -75,19 +76,19 @@ Tensor& copy_( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in); - ScalarType in_type = in.scalar_type(); - ScalarType src_type = src.scalar_type(); - - ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { - ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { - apply_binary_elementwise_fn( - [](const CTYPE val_in, const CTYPE_SRC val_src) { - return convert(val_src); - }, - in, - src, - in); - }); + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "copy_"; + + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; }, + ctx, + in, + utils::SupportedTensorDtypes::REALHBBF16, + src, + utils::SupportedTensorDtypes::REALHBBF16, + in, + utils::SupportedTensorDtypes::REALHBBF16); }); return in; diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 931a3be4ad1..538e20bd126 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -425,6 +425,8 @@ ATEN_OPS = ( name = "op_copy", deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:dtype_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/runtime/core/exec_aten/util:scalar_type_util", "//executorch/runtime/core/exec_aten/util:tensor_util", ":scalar_utils",