Skip to content

Commit 91dbce7

Browse files
manuelcandales and facebook-github-bot
authored and committed
Dtype compliance: slice_scatter
Reviewed By: SS-JIA Differential Revision: D48318692 fbshipit-source-id: d39ae3a2dfe8b1e260ab14ce66144862c7f26867
1 parent 97d6b65 commit 91dbce7

File tree

2 files changed

+30
-41
lines changed

2 files changed

+30
-41
lines changed

kernels/portable/cpu/op_slice_scatter.cpp

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ void check_input_args(
3939
// Input and output tensors should be the same shape
4040
ET_CHECK_SAME_SHAPE2(input, output);
4141

42-
// All tensors should have the same dtype
43-
ET_CHECK_SAME_DTYPE3(input, src, output);
42+
// Input and output tensors should have the same dtype
43+
ET_CHECK_SAME_DTYPE2(input, output);
4444

4545
// input.dim() shall be equal to src.dim()
4646
ET_CHECK_MSG(
@@ -158,28 +158,36 @@ Tensor& slice_scatter_out(
158158
check_input_args(input, src, dim, num_values, step, out);
159159

160160
size_t dim_length = input.size(dim);
161-
162161
size_t leading_dims = getLeadingDims(input, dim);
163162
size_t trailing_dims = getTrailingDims(input, dim);
164163

165-
size_t length_per_step = trailing_dims * input.element_size();
166-
167-
const char* in_data = input.const_data_ptr<char>();
168-
const char* src_data = src.const_data_ptr<char>();
169-
170-
char* out_data = out.mutable_data_ptr<char>();
171-
172164
// To start, copy the input into the output
173-
memcpy(out_data, in_data, input.nbytes());
174-
175-
for (int i = 0; i < leading_dims; i++) {
176-
char* dst = out_data + (i * dim_length + start) * length_per_step;
177-
for (int j = 0; j < num_values; j++) {
178-
memcpy(dst, src_data, length_per_step);
179-
src_data += length_per_step;
180-
dst += step * length_per_step;
181-
}
182-
}
165+
memcpy(out.mutable_data_ptr(), input.const_data_ptr(), input.nbytes());
166+
167+
ScalarType in_type = input.scalar_type();
168+
ScalarType src_type = src.scalar_type();
169+
170+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE, [&]() {
171+
ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, __func__, CTYPE_SRC, [&]() {
172+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
173+
const CTYPE_SRC* src_data = src.const_data_ptr<CTYPE_SRC>();
174+
175+
size_t src_offset = 0;
176+
177+
for (int i = 0; i < leading_dims; i++) {
178+
size_t out_offset = (i * dim_length + start) * trailing_dims;
179+
for (int j = 0; j < num_values; j++) {
180+
for (size_t k = 0; k < trailing_dims; ++k) {
181+
out_data[out_offset + k] =
182+
convert<CTYPE, CTYPE_SRC>(src_data[src_offset + k]);
183+
}
184+
src_offset += trailing_dims;
185+
out_offset += step * trailing_dims;
186+
}
187+
}
188+
});
189+
});
190+
183191
return out;
184192
}
185193

kernels/test/op_slice_scatter_test.cpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -666,12 +666,9 @@ void test_dtype() {
666666
EXPECT_TENSOR_EQ(ret, expect_ret);
667667
}
668668

669-
TEST(OpSliceCopyTensorOutTest, AllDtypesSupported) {
670-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
671-
GTEST_SKIP() << "ATen kernel test fails";
672-
}
669+
TEST(OpSliceCopyTensorOutTest, AllRealDtypesSupported) {
673670
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
674-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
671+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
675672
#undef TEST_ENTRY
676673
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
677674
// way to do that would be to make TensorFactory support zeros() and ones()
@@ -755,22 +752,6 @@ TEST(OpSliceCopyTensorOutTest, MismatchedOutDtypesDies) {
755752
input, src, /*dim=*/0, /*start=*/0, /*end=*/1, /*step=*/1, out));
756753
}
757754

758-
TEST(OpSliceCopyTensorOutTest, MismatchedSrcDtypesDies) {
759-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
760-
GTEST_SKIP() << "ATen kernel can handle src with mismatched dtype";
761-
}
762-
TensorFactory<ScalarType::Int> tf_int;
763-
TensorFactory<ScalarType::Float> tf_float;
764-
Tensor input = tf_int.zeros({1, 2, 2});
765-
Tensor src = tf_float.zeros({1, 2, 2});
766-
767-
// Size is compatible to the output, but a mismatched dtype.
768-
Tensor out = tf_int.ones({1, 2, 2});
769-
770-
ET_EXPECT_KERNEL_FAILURE(op_slice_scatter_out(
771-
input, src, /*dim=*/0, /*start=*/0, /*end=*/1, /*step=*/1, out));
772-
}
773-
774755
TEST(OpSliceCopyTensorOutTest, OutSizeMismatchDimDies) {
775756
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
776757
GTEST_SKIP() << "ATen kernel can handle out with mismatched dimensions";

0 commit comments

Comments
 (0)