|
7 | 7 | */
|
8 | 8 |
|
9 | 9 | #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
|
| 10 | +#include <executorch/kernels/test/ScalarOverflowTestMacros.h> |
10 | 11 | #include <executorch/kernels/test/TestUtil.h>
|
11 | 12 | #include <executorch/kernels/test/supported_features.h>
|
12 | 13 | #include <executorch/runtime/core/exec_aten/exec_aten.h>
|
@@ -364,6 +365,19 @@ class OpScatterValueOutTest : public OperatorTest {
|
364 | 365 | op_scatter_value_out(input, 2, index, value, out);
|
365 | 366 | EXPECT_TENSOR_EQ(out, expected);
|
366 | 367 | }
|
| 368 | + |
| 369 | + template <ScalarType DTYPE> |
| 370 | + void expect_bad_scalar_value_dies(const Scalar& bad_value) { |
| 371 | + TensorFactory<DTYPE> tf; |
| 372 | + TensorFactory<ScalarType::Long> tf_index; |
| 373 | + |
| 374 | + Tensor self = tf.ones({2, 2}); |
| 375 | + Tensor index = tf_index.zeros({2, 2}); |
| 376 | + Tensor out = tf.zeros({2, 2}); |
| 377 | + |
| 378 | + ET_EXPECT_KERNEL_FAILURE( |
| 379 | + context_, op_scatter_value_out(self, 0, index, bad_value, out)); |
| 380 | + } |
367 | 381 | };
|
368 | 382 |
|
369 | 383 | TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) {
|
@@ -652,3 +666,5 @@ TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) {
|
652 | 666 | ET_EXPECT_KERNEL_FAILURE(
|
653 | 667 | context_, op_scatter_src_out(self, 0, index, src, out));
|
654 | 668 | }
|
| 669 | + |
| 670 | +GENERATE_SCALAR_OVERFLOW_TESTS(OpScatterValueOutTest) |
0 commit comments