Skip to content

Commit c1f4d59

Browse files
[ET][Portable] Check scalar overflow: op_scatter (#12084)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12028 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/126/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/126/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/124/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/126/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent 42fda56 commit c1f4d59

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

kernels/portable/cpu/op_scatter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ Tensor& scatter_value_out(
154154
constexpr auto name = "scatter.value_out";
155155

156156
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
157-
const CTYPE val = utils::scalar_to<CTYPE>(value);
157+
auto opt_val = utils::internal::check_overflow_scalar_cast<CTYPE>(value);
158+
ET_KERNEL_CHECK(ctx, opt_val.has_value(), InvalidArgument, );
159+
auto val = opt_val.value();
158160
scatter_value_helper<CTYPE>(in, dim, index, val, out);
159161
});
160162

kernels/test/op_scatter_test.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -364,6 +365,19 @@ class OpScatterValueOutTest : public OperatorTest {
364365
op_scatter_value_out(input, 2, index, value, out);
365366
EXPECT_TENSOR_EQ(out, expected);
366367
}
368+
369+
template <ScalarType DTYPE>
370+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
371+
TensorFactory<DTYPE> tf;
372+
TensorFactory<ScalarType::Long> tf_index;
373+
374+
Tensor self = tf.ones({2, 2});
375+
Tensor index = tf_index.zeros({2, 2});
376+
Tensor out = tf.zeros({2, 2});
377+
378+
ET_EXPECT_KERNEL_FAILURE(
379+
context_, op_scatter_value_out(self, 0, index, bad_value, out));
380+
}
367381
};
368382

369383
TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) {
@@ -652,3 +666,5 @@ TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) {
652666
ET_EXPECT_KERNEL_FAILURE(
653667
context_, op_scatter_src_out(self, 0, index, src, out));
654668
}
669+
670+
GENERATE_SCALAR_OVERFLOW_TESTS(OpScatterValueOutTest)

0 commit comments

Comments
 (0)