using Tensor = exec_aten::Tensor;
// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
// TODO: We actually shouldn't see this op with the proper functionalization,
// and this op needs to be deleted
/// copy.out: copies `src` into `out`, broadcasting `src` to the shape implied
/// by broadcasting `src` against `in`. Element values are converted from the
/// source dtype to the in/out dtype.
///
/// @param ctx          Kernel runtime context, used by the ET_KERNEL_CHECK /
///                     ET_SWITCH failure paths.
/// @param in           Tensor whose shape (broadcast with `src`) and dtype
///                     define the output; its data is not read.
/// @param src          Tensor whose values are copied; must be broadcastable
///                     to `in`.
/// @param non_blocking Must be false — only blocking transfer is supported.
/// @param out          Destination tensor; resized to the broadcast shape.
///                     Must share `in`'s dtype.
/// @return `out`.
Tensor& copy_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& src,
    bool non_blocking,
    Tensor& out) {
  (void)ctx;

  // Right now we only support blocking data transfer.
  // NOTE: the macro's last argument is the value returned on failure; it must
  // be a Tensor& to match this function's return type, so `out` is used here
  // (passing `non_blocking`, a bool, would not type-check).
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);

  // `src` and `out` may differ in dtype (values are converted below), but
  // `in` and `out` must agree since `out` stands in for `in` after the copy.
  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

  Tensor::SizesType expected_output_size[kTensorDimensionLimit];
  size_t expected_output_dim = 0;

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, src);

  // Compute the broadcast target shape of (in, src) into
  // expected_output_size / expected_output_dim.
  get_broadcast_target_size(
      in,
      src,
      expected_output_size,
      kTensorDimensionLimit,
      &expected_output_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_output_size, expected_output_dim}) ==
          Error::Ok,
      InvalidArgument,
      out);

  ScalarType in_type = in.scalar_type();
  ScalarType src_type = src.scalar_type();

  // Dispatch over (in, src) dtype pairs; each element of `out` is the
  // corresponding (broadcast) element of `src`, converted to in/out's dtype.
  // `val_in` is intentionally ignored — copy only reads from `src`.
  ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
    ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
      apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
          [](const CTYPE val_in, const CTYPE_SRC val_src) {
            return convert<CTYPE, CTYPE_SRC>(val_src);
          },
          in,
          src,
          out);
    });
  });

  return out;
}
0 commit comments