Skip to content

Commit 76e2e12

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Dtype compliance: split_copy
Reviewed By: SS-JIA Differential Revision: D48318690 fbshipit-source-id: 935759301c9a4deada20550bb4a2de549646f9ac
1 parent 88eda7b commit 76e2e12

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

kernels/portable/cpu/op_split_copy.cpp

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,13 @@ void check_args(
8787

8888
// Validate each output.
8989
for (size_t i = 0; i < out.size(); ++i) {
90-
// All output dtypes must match the input type.
90+
// All output dtypes must be the same.
9191
ET_CHECK_MSG(
92-
out[i].scalar_type() == input.scalar_type(),
93-
"out[%zu] dtype %hhd != input dtype %hhd",
92+
out[i].scalar_type() == out[0].scalar_type(),
93+
"out[%zu] dtype %hhd != out[0] dtype %hhd",
9494
i,
9595
out[i].scalar_type(),
96-
input.scalar_type());
96+
out[0].scalar_type());
9797

9898
// All outputs must have the same number of dimensions as the input.
9999
ET_CHECK_MSG(
@@ -170,26 +170,32 @@ void split_copy_Tensor_out(
170170

171171
const size_t leading_dims = getLeadingDims(input, dim);
172172
const size_t trailing_dims = getTrailingDims(input, dim);
173-
174-
const size_t element_size = input.element_size();
175-
const size_t step = input.size(dim) * trailing_dims * element_size;
176-
177-
const char* input_data = input.const_data_ptr<char>();
178-
for (size_t i = 0, e = out.size(); i < e; ++i) {
179-
size_t num_bytes = out[i].size(dim) * trailing_dims * element_size;
180-
if (num_bytes == 0) {
181-
continue;
182-
}
183-
184-
const char* src = input_data;
185-
char* dest = out[i].mutable_data_ptr<char>();
186-
for (size_t j = 0; j < leading_dims; ++j) {
187-
memcpy(dest, src, num_bytes);
188-
src += step;
189-
dest += num_bytes;
190-
}
191-
input_data += num_bytes;
192-
}
173+
const size_t step = input.size(dim) * trailing_dims;
174+
175+
ScalarType in_type = input.scalar_type();
176+
ScalarType out_type = out[0].scalar_type();
177+
178+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, __func__, CTYPE_IN, [&]() {
179+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() {
180+
const CTYPE_IN* input_data = input.const_data_ptr<CTYPE_IN>();
181+
for (size_t i = 0, e = out.size(); i < e; ++i) {
182+
size_t out_step = out[i].size(dim) * trailing_dims;
183+
if (out_step == 0) {
184+
continue;
185+
}
186+
const CTYPE_IN* src = input_data;
187+
CTYPE_OUT* dest = out[i].mutable_data_ptr<CTYPE_OUT>();
188+
for (size_t j = 0; j < leading_dims; ++j) {
189+
for (size_t k = 0; k < out_step; ++k) {
190+
dest[k] = convert<CTYPE_OUT, CTYPE_IN>(src[k]);
191+
}
192+
src += step;
193+
dest += out_step;
194+
}
195+
input_data += out_step;
196+
}
197+
});
198+
});
193199
}
194200

195201
} // namespace native

kernels/test/op_split_copy_test.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,7 @@ TEST(OpSplitCopyTensorOutTest, OutOfRangeDimsDie) {
356356
}
357357

358358
TEST(OpSplitCopyTensorOutTest, DtypeMismatchDies) {
359-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
360-
GTEST_SKIP() << "ATen kernel can handle dtype mismatch";
361-
}
359+
GTEST_SKIP() << "ATen kernel can handle dtype mismatch";
362360
TensorFactory<ScalarType::Int> tf_int;
363361
TensorListFactory<ScalarType::Int> tlf_int;
364362
TensorListFactory<ScalarType::Float> tlf_float;

0 commit comments

Comments
 (0)