Skip to content

Commit 88eda7b

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Dtype compliance: unbind_copy
Reviewed By: SS-JIA Differential Revision: D48318691 fbshipit-source-id: 6d01021e2e567addeb5faaf5e9714c9d34139528
1 parent 91dbce7 commit 88eda7b

File tree

2 files changed

+37
-36
lines changed

2 files changed

+37
-36
lines changed

kernels/portable/cpu/op_unbind_copy.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ void check_args(const Tensor& input, int64_t dim, TensorList out) {
4141

4242
// Validate each output.
4343
for (size_t i = 0; i < out.size(); ++i) {
44-
// All output dtypes must match the input type.
44+
// All output dtypes must be the same.
4545
ET_CHECK_MSG(
46-
out[i].scalar_type() == input.scalar_type(),
47-
"out[%zu] dtype %hhd != input dtype %hhd",
46+
out[i].scalar_type() == out[0].scalar_type(),
47+
"out[%zu] dtype %hhd != out[0] dtype %hhd",
4848
i,
4949
out[i].scalar_type(),
50-
input.scalar_type());
50+
out[0].scalar_type());
5151

5252
// output tensor must have # of dims = input.dim() -1
5353
ET_CHECK_MSG(
@@ -97,25 +97,29 @@ void unbind_copy_int_out(
9797

9898
const size_t leading_dims = getLeadingDims(input, dim);
9999
const size_t trailing_dims = getTrailingDims(input, dim);
100-
101-
const size_t element_size = input.element_size();
102-
const size_t step = input.size(dim) * trailing_dims * element_size;
103-
104-
const char* input_data = input.const_data_ptr<char>();
105-
for (size_t i = 0, e = out.size(); i < e; ++i) {
106-
size_t num_bytes = trailing_dims * element_size;
107-
// num_bytes should not be zero because trailing_dims
108-
// will at least return 1
109-
110-
const char* src = input_data;
111-
char* dest = out[i].mutable_data_ptr<char>();
112-
for (size_t j = 0; j < leading_dims; ++j) {
113-
memcpy(dest, src, num_bytes);
114-
src += step;
115-
dest += num_bytes;
116-
}
117-
input_data += num_bytes;
118-
}
100+
const size_t step = input.size(dim) * trailing_dims;
101+
102+
ScalarType in_type = input.scalar_type();
103+
ScalarType out_type = out[0].scalar_type();
104+
105+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
106+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
107+
const CTYPE_IN* const input_data = input.const_data_ptr<CTYPE_IN>();
108+
for (size_t i = 0, e = out.size(); i < e; ++i) {
109+
size_t input_offset = i * trailing_dims;
110+
CTYPE_OUT* const dest = out[i].mutable_data_ptr<CTYPE_OUT>();
111+
size_t dest_offset = 0;
112+
for (size_t j = 0; j < leading_dims; ++j) {
113+
for (size_t k = 0; k < trailing_dims; ++k) {
114+
dest[dest_offset + k] =
115+
convert<CTYPE_OUT, CTYPE_IN>(input_data[input_offset + k]);
116+
}
117+
input_offset += step;
118+
dest_offset += trailing_dims;
119+
}
120+
}
121+
});
122+
});
119123
}
120124

121125
} // namespace native

kernels/test/op_unbind_copy_test.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ Tensor make1x2x3(TensorFactory<DTYPE>& tf) {
4848

4949
template <ScalarType DTYPE>
5050
void test_unbind_dim0() {
51-
TensorFactory<ScalarType::Int> tf;
52-
TensorListFactory<ScalarType::Int> tlf;
51+
TensorFactory<DTYPE> tf;
52+
TensorListFactory<DTYPE> tlf;
5353

5454
// clang-format off
5555
std::vector<Tensor> expected_out = {
@@ -117,8 +117,8 @@ void test_unbind_dim1() {
117117

118118
template <ScalarType DTYPE>
119119
void test_unbind_dim2() {
120-
TensorFactory<ScalarType::Int> tf;
121-
TensorListFactory<ScalarType::Int> tlf;
120+
TensorFactory<DTYPE> tf;
121+
TensorListFactory<DTYPE> tlf;
122122

123123
// Splitting on dim=N with split_size=2 will produce a list of tensors where
124124
// the max dim[N] is 2, and the other dims are the same as the input.
@@ -164,24 +164,21 @@ void test_unbind_dim2() {
164164
EXPECT_TENSOR_LISTS_EQ(expected_out, out2);
165165
}
166166

167-
TEST(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim0AllSupportedDtypes) {
167+
TEST(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim0AllRealDtypes) {
168168
#define TEST_ENTRY(ctype, dtype) test_unbind_dim0<ScalarType::dtype>();
169-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
169+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
170170
#undef TEST_ENTRY
171171
}
172172

173-
TEST(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim1AllSupportedDTypes) {
174-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
175-
GTEST_SKIP() << "ATen kernel test fails";
176-
}
173+
TEST(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim1AllRealDTypes) {
177174
#define TEST_ENTRY(ctype, dtype) test_unbind_dim1<ScalarType::dtype>();
178-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
175+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
179176
#undef TEST_ENTRY
180177
}
181178

182-
TEST(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim2) {
179+
TEST(OpUnbindCopyIntOutTest, Unbind1x2x3OnDim2AllRealDTypes) {
183180
#define TEST_ENTRY(ctype, dtype) test_unbind_dim2<ScalarType::dtype>();
184-
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
181+
ET_FORALL_REAL_TYPES(TEST_ENTRY);
185182
#undef TEST_ENTRY
186183
}
187184

0 commit comments

Comments
 (0)