Skip to content

Commit c638851

Browse files
Fix BFloat16 support for op_index
Differential Revision: D82134506 Pull Request resolved: #14167
1 parent 3124a6b commit c638851

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ Tensor& index_Tensor_out(
213213
if (block_count == 0) {
214214
ET_KERNEL_CHECK(
215215
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
216-
ET_SWITCH_REALHB_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
216+
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "index.Tensor_out", CTYPE, [&]() {
217217
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
218218
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
219219
memcpy(out_data, in_data, in.nbytes());

kernels/test/op_index_test.cpp

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,36 @@ class OpIndexTensorOutTest : public OperatorTest {
109109

110110
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
111111

112+
#undef TEST_ENTRY
113+
}
114+
115+
template <executorch::aten::ScalarType INPUT_DTYPE>
116+
void test_indices_with_only_null_tensors_supported() {
117+
TensorFactory<INPUT_DTYPE> tf;
118+
119+
Tensor x = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
120+
Tensor out = tf.zeros({2, 3});
121+
122+
std::array<optional<Tensor>, 1> indices1 = {optional<Tensor>()};
123+
op_index_tensor_out(x, indices1, out);
124+
EXPECT_TENSOR_EQ(out, x);
125+
126+
out = tf.zeros({2, 3});
127+
std::array<optional<Tensor>, 2> indices2 = {
128+
optional<Tensor>(), std::optional<Tensor>()};
129+
op_index_tensor_out(x, indices2, out);
130+
EXPECT_TENSOR_EQ(out, x);
131+
}
132+
133+
/**
134+
* Test indices with only null tensors for all input data types
135+
*/
136+
void test_indices_with_only_null_tensors_enumerate_in_types() {
137+
#define TEST_ENTRY(ctype, dtype) \
138+
test_indices_with_only_null_tensors_supported<ScalarType::dtype>();
139+
140+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
141+
112142
#undef TEST_ENTRY
113143
}
114144

@@ -405,21 +435,19 @@ TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) {
405435
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
406436
GTEST_SKIP() << "ATen kernel test fails";
407437
}
408-
TensorFactory<ScalarType::Double> tf;
438+
test_indices_with_only_null_tensors_enumerate_in_types();
439+
}
409440

441+
TEST_F(OpIndexTensorOutTest, TooManyNullIndices) {
442+
TensorFactory<ScalarType::Double> tf;
410443
Tensor x = tf.make({2, 3}, {1., 2., 3., 4., 5., 6.});
411-
std::array<optional<Tensor>, 1> indices0 = {optional<Tensor>()};
412-
run_test_cases(x, indices0, x);
413-
414-
std::array<optional<Tensor>, 2> indices1 = {
415-
optional<Tensor>(), std::optional<Tensor>()};
416-
run_test_cases(x, indices1, x);
417-
418-
std::array<optional<Tensor>, 3> indices2 = {
444+
std::array<optional<Tensor>, 3> indices = {
419445
optional<Tensor>(), std::optional<Tensor>(), std::optional<Tensor>()};
420446
Tensor out = tf.ones({2, 3});
421447
ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
422-
context_, op_index_tensor_out(x, indices2, out), "");
448+
context_,
449+
op_index_tensor_out(x, indices, out),
450+
"Indexing too many dimensions");
423451
}
424452

425453
TEST_F(OpIndexTensorOutTest, EmptyIndicesSupported) {

0 commit comments

Comments
 (0)