Skip to content

Commit a82d03c

Browse files
manuelcandales authored and facebook-github-bot committed
Dtype compliance: copy
Reviewed By: SS-JIA
Differential Revision: D48289832
fbshipit-source-id: f5f1ca118adb0fb4dc15096ebe59e26016f6bb1a
1 parent af67e28 commit a82d03c

File tree

2 files changed

+29
-38
lines changed

2 files changed

+29
-38
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,62 +17,56 @@ namespace native {
1717

1818
using Tensor = exec_aten::Tensor;
1919

20-
// copy.out(const Tensor& self, const Tensor& src, bool non_blocking, Tensor(a!)
20+
// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
2121
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
2222
// TODO: We actually shouldn't see this op with the proper functionalization,
2323
// and this op needs to be deleted
2424
Tensor& copy_out(
2525
RuntimeContext& ctx,
26-
const Tensor& self,
26+
const Tensor& in,
2727
const Tensor& src,
2828
bool non_blocking,
2929
Tensor& out) {
3030
(void)ctx;
3131
// Right now we only support blocking data transfer
32-
ET_CHECK(non_blocking == false);
32+
ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, non_blocking);
3333

34-
// The srs and out shall share same dtype, but not necessarily for self,
35-
// because `auto intermediate = src.to(self, non_blocking)` doesn't restrict
36-
// the type of self. In this kernel we didn't do `to` inside the op. If
37-
// in the short term we need self in a different type, can extend the op to
38-
// cover it.
39-
ET_CHECK_SAME_DTYPE3(self, src, out);
34+
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
4035

4136
Tensor::SizesType expected_output_size[kTensorDimensionLimit];
4237
size_t expected_output_dim = 0;
4338

44-
ET_CHECK_MSG(
45-
tensor_is_broadcastable_to(src, self),
46-
"can't broadcast from self to src");
39+
ET_KERNEL_CHECK(
40+
ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, src);
41+
4742
get_broadcast_target_size(
48-
self,
43+
in,
4944
src,
5045
expected_output_size,
5146
kTensorDimensionLimit,
5247
&expected_output_dim);
5348

54-
ET_CHECK_MSG(
55-
Error::Ok ==
56-
resize_tensor(out, {expected_output_size, expected_output_dim}),
57-
"Failed to resize output tensor.");
58-
bool to_be_broadcasted_src = !out.sizes().equals(src.sizes());
49+
ET_KERNEL_CHECK(
50+
ctx,
51+
resize_tensor(out, {expected_output_size, expected_output_dim}) ==
52+
Error::Ok,
53+
InvalidArgument,
54+
out);
5955

60-
// ET_CHECK_SAME_SHAPE2(expected_output_size, self);
61-
const Tensor& broadcasted_src =
62-
to_be_broadcasted_src ? broadcast_tensor(src, out) : src;
56+
ScalarType in_type = in.scalar_type();
57+
ScalarType src_type = src.scalar_type();
6358

64-
if (broadcasted_src.nbytes() > 0) {
65-
// Note that this check is important. It's valid for a tensor with numel 0
66-
// to have a null data pointer, but in some environments it's invalid to
67-
// pass a null pointer to memcpy() even when the size is zero.
68-
memcpy(
69-
out.mutable_data_ptr(),
70-
broadcasted_src.const_data_ptr(),
71-
broadcasted_src.nbytes());
72-
}
73-
if (to_be_broadcasted_src) {
74-
free_broadcast_tensor(broadcasted_src);
75-
}
59+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
60+
ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
61+
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
62+
[](const CTYPE val_in, const CTYPE_SRC val_src) {
63+
return convert<CTYPE, CTYPE_SRC>(val_src);
64+
},
65+
in,
66+
src,
67+
out);
68+
});
69+
});
7670

7771
return out;
7872
}

kernels/test/op_copy_test.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,9 @@ void test_dtype() {
6565
EXPECT_TENSOR_EQ(src, out_contiguous_ret);
6666
}
6767

68-
TEST(OpCopyTest, AllDtypesSupported) {
69-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
70-
GTEST_SKIP() << "ATen kernel test fails";
71-
}
68+
TEST(OpCopyTest, AllRealDtypesSupported) {
7269
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
73-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
70+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
7471
#undef TEST_ENTRY
7572
}
7673

0 commit comments

Comments
 (0)