@@ -109,6 +109,36 @@ class OpIndexTensorOutTest : public OperatorTest {
109
109
110
110
ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
111
111
112
+ #undef TEST_ENTRY
113
+ }
114
+
115
+ template <executorch::aten::ScalarType INPUT_DTYPE>
116
+ void test_indices_with_only_null_tensors_supported () {
117
+ TensorFactory<INPUT_DTYPE> tf;
118
+
119
+ Tensor x = tf.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
120
+ Tensor out = tf.zeros ({2 , 3 });
121
+
122
+ std::array<optional<Tensor>, 1 > indices1 = {optional<Tensor>()};
123
+ op_index_tensor_out (x, indices1, out);
124
+ EXPECT_TENSOR_EQ (out, x);
125
+
126
+ out = tf.zeros ({2 , 3 });
127
+ std::array<optional<Tensor>, 2 > indices2 = {
128
+ optional<Tensor>(), std::optional<Tensor>()};
129
+ op_index_tensor_out (x, indices2, out);
130
+ EXPECT_TENSOR_EQ (out, x);
131
+ }
132
+
133
+ /* *
134
+ * Test indices with only null tensors for all input data types
135
+ */
136
+ void test_indices_with_only_null_tensors_enumerate_in_types () {
137
+ #define TEST_ENTRY (ctype, dtype ) \
138
+ test_indices_with_only_null_tensors_supported<ScalarType::dtype>();
139
+
140
+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
141
+
112
142
#undef TEST_ENTRY
113
143
}
114
144
@@ -405,21 +435,19 @@ TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) {
405
435
if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
406
436
GTEST_SKIP () << " ATen kernel test fails" ;
407
437
}
408
- TensorFactory<ScalarType::Double> tf;
438
+ test_indices_with_only_null_tensors_enumerate_in_types ();
439
+ }
409
440
441
+ TEST_F (OpIndexTensorOutTest, TooManyNullIndices) {
442
+ TensorFactory<ScalarType::Double> tf;
410
443
Tensor x = tf.make ({2 , 3 }, {1 ., 2 ., 3 ., 4 ., 5 ., 6 .});
411
- std::array<optional<Tensor>, 1 > indices0 = {optional<Tensor>()};
412
- run_test_cases (x, indices0, x);
413
-
414
- std::array<optional<Tensor>, 2 > indices1 = {
415
- optional<Tensor>(), std::optional<Tensor>()};
416
- run_test_cases (x, indices1, x);
417
-
418
- std::array<optional<Tensor>, 3 > indices2 = {
444
+ std::array<optional<Tensor>, 3 > indices = {
419
445
optional<Tensor>(), std::optional<Tensor>(), std::optional<Tensor>()};
420
446
Tensor out = tf.ones ({2 , 3 });
421
447
ET_EXPECT_KERNEL_FAILURE_WITH_MSG (
422
- context_, op_index_tensor_out (x, indices2, out), " " );
448
+ context_,
449
+ op_index_tensor_out (x, indices, out),
450
+ " Indexing too many dimensions" );
423
451
}
424
452
425
453
TEST_F (OpIndexTensorOutTest, EmptyIndicesSupported) {
0 commit comments