Skip to content

Commit 42fda56

Browse files
[ET][Portable] Check scalar overflow: op_constant_pad_nd (#12083)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12026 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/124/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/124/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/123/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/124/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent 4b8a584 commit 42fda56

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

kernels/portable/cpu/op_constant_pad_nd.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,10 @@ Tensor& constant_pad_nd_out(
185185
ScalarType in_type = in.scalar_type();
186186

187187
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "constant_pad_nd.out", CTYPE, [&]() {
188-
const CTYPE value_casted = utils::scalar_to<CTYPE>(value);
188+
auto opt_value_casted =
189+
utils::internal::check_overflow_scalar_cast<CTYPE>(value);
190+
ET_KERNEL_CHECK(ctx, opt_value_casted.has_value(), InvalidArgument, );
191+
auto value_casted = opt_value_casted.value();
189192
constant_pad_nd_out_impl<CTYPE>(in, pad, value_casted, out);
190193
});
191194

kernels/test/op_constant_pad_nd_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -347,6 +348,21 @@ class OpConstantPadNDOutTest : public OperatorTest {
347348
op_constant_pad_nd_out(self, padding_ref, 7, out);
348349
EXPECT_TENSOR_CLOSE(out, expected);
349350
}
351+
352+
template <ScalarType DTYPE>
353+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
354+
TensorFactory<DTYPE> tf;
355+
const std::vector<int32_t> sizes = {2, 2};
356+
const std::vector<int32_t> sizes_out = {2, 4};
357+
const std::vector<int64_t> padding = {1, 1};
358+
359+
IntArrayRef padding_ref = IntArrayRef(padding.data(), padding.size());
360+
Tensor self = tf.ones(sizes);
361+
Tensor out = tf.zeros(sizes_out);
362+
363+
ET_EXPECT_KERNEL_FAILURE(
364+
context_, op_constant_pad_nd_out(self, padding_ref, bad_value, out));
365+
}
350366
};
351367

352368
TEST_F(OpConstantPadNDOutTest, TestPadDim2) {
@@ -465,3 +481,5 @@ TEST_F(OpConstantPadNDOutTest, IncorrectOutputShapeFail) {
465481
ET_EXPECT_KERNEL_FAILURE(
466482
context_, op_constant_pad_nd_out(self, padding_ref, 0, out));
467483
}
484+
485+
GENERATE_SCALAR_OVERFLOW_TESTS(OpConstantPadNDOutTest)

0 commit comments

Comments
 (0)