  * LICENSE file in the root directory of this source tree.
  */
 
-#include <cstring>
-
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-
-namespace torch {
-namespace executor {
-namespace native {
-
-using Tensor = executorch::aten::Tensor;
-
-// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
-// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
-// TODO: We actually shouldn't see this op with the proper functionalization,
-// and this op needs to be deleted
-Tensor& copy_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& in,
-    const Tensor& src,
-    bool non_blocking,
-    Tensor& out) {
-  (void)ctx;
-  // Right now we only support blocking data transfer
-  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);
-
-  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);
-
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "copy.out";
-
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
-        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
-        ctx,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        src,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBBF16);
-  });
-
-  return out;
-}
-
-Tensor& copy_(
-    KernelRuntimeContext& ctx,
-    Tensor& in,
-    const Tensor& src,
-    bool non_blocking) {
-  (void)ctx;
-  // Right now we only support blocking data transfer
-  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);
-
-  ET_KERNEL_CHECK(
-      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "copy_";
-
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
-        [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
-        ctx,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        src,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        in,
-        utils::SupportedTensorDtypes::REALHBBF16);
-  });
-
-  return in;
-}
-
-} // namespace native
-} // namespace executor
-} // namespace torch
+#include <cstring>
+
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+
+// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
+// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
+// TODO: We actually shouldn't see this op with the proper functionalization,
+// and this op needs to be deleted
+Tensor& copy_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& src,
+    bool non_blocking,
+    Tensor& out) {
+  (void)ctx;
+  // Right now we only support blocking data transfer
+  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "copy.out";
+
+  // Use direct copy fast path if broadcast is not needed and tensors are
+  // non-empty
+  if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) &&
+      src.numel() > 0) {
+    std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
+      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+          [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
+          ctx,
+          in,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          src,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out,
+          utils::SupportedTensorDtypes::REALHBBF16);
+    });
+  }
+
+  return out;
+}
+
+Tensor& copy_(
+    KernelRuntimeContext& ctx,
+    Tensor& in,
+    const Tensor& src,
+    bool non_blocking) {
+  (void)ctx;
+  // Right now we only support blocking data transfer
+  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);
+
+  ET_KERNEL_CHECK(
+      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "copy_";
+
+  // Use direct copy fast path if broadcast is not needed and tensors are
+  // non-empty
+  if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) &&
+      src.numel() > 0) {
+    std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
+      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+          [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
+          ctx,
+          in,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          src,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          in,
+          utils::SupportedTensorDtypes::REALHBBF16);
+    });
+  }
+
+  return in;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
+
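
The added fast path only fires when copying src needs no real broadcasting: the output and source shapes must be identical once leading size-1 dimensions are ignored, so a single flat memcpy of the payload is valid (contiguity and dim-order concerns are handled by the checks earlier in each function). As an illustration of the shape test this relies on, here is a minimal standalone sketch with a hypothetical name; it is not the actual internal::sizes_match_ignoring_leading_1s helper from broadcast_util.h, whose implementation may differ.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-in for the shape check the fast path relies on.
// Returns true when the two shapes are identical after dropping leading
// size-1 dimensions, i.e. copying src into out needs no per-element
// broadcasting and a flat memcpy of the data is safe.
bool sizes_match_ignoring_leading_1s_sketch(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  std::size_t i = 0;
  std::size_t j = 0;
  // Skip leading size-1 dimensions on both shapes.
  while (i < a.size() && a[i] == 1) {
    ++i;
  }
  while (j < b.size() && b[j] == 1) {
    ++j;
  }
  // The remaining dimensions must agree in count and value.
  if (a.size() - i != b.size() - j) {
    return false;
  }
  while (i < a.size()) {
    if (a[i++] != b[j++]) {
      return false;
    }
  }
  return true;
}

// Example: {1, 1, 4, 3} vs {4, 3} -> true, the memcpy fast path applies;
//          {4, 1} vs {4, 3}       -> false, the kernel falls back to the
//          broadcast-aware elementwise copy.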