Skip to content

Commit f25265b

Browse files
[ET][Portable] Check scalar overflow: op_full (#12080)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12014 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/121/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/121/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/121/orig @diff-train-skip-merge Co-authored-by: Manuel Candales <[email protected]>
1 parent cc3a35f commit f25265b

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

kernels/portable/cpu/op_full.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ Tensor& full_out(
3737
constexpr auto name = "full.out";
3838

3939
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
40-
CTYPE_OUT val_casted = utils::scalar_to<CTYPE_OUT>(fill_value);
40+
auto opt_val_casted =
41+
utils::internal::check_overflow_scalar_cast<CTYPE_OUT>(fill_value);
42+
ET_KERNEL_CHECK(ctx, opt_val_casted.has_value(), InvalidArgument, );
43+
auto val_casted = opt_val_casted.value();
4144
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
4245
for (const auto i : c10::irange(out.numel())) {
4346
data_out[i] = val_casted;

kernels/test/op_full_test.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -59,6 +60,17 @@ class OpFullOutTest : public OperatorTest {
5960
op_full_out(aref, 1.0, out);
6061
EXPECT_TENSOR_EQ(out, tf.ones(size_int32_t));
6162
}
63+
64+
template <ScalarType DTYPE>
65+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
66+
TensorFactory<DTYPE> tf;
67+
std::vector<int32_t> sizes = {2, 2};
68+
std::vector<int64_t> sizes_int64_t(sizes.begin(), sizes.end());
69+
auto aref = IntArrayRef(sizes_int64_t.data(), sizes_int64_t.size());
70+
Tensor out = tf.zeros(sizes);
71+
72+
ET_EXPECT_KERNEL_FAILURE(context_, op_full_out(aref, bad_value, out));
73+
}
6274
};
6375

6476
#define GENERATE_TEST(_, DTYPE) \
@@ -72,20 +84,7 @@ class OpFullOutTest : public OperatorTest {
7284

7385
ET_FORALL_REALHBF16_TYPES(GENERATE_TEST)
7486

75-
TEST_F(OpFullOutTest, ValueOverflow) {
76-
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
77-
GTEST_SKIP() << "ATen kernel doesn't handle overflow";
78-
}
79-
TensorFactory<ScalarType::Byte> tf;
80-
81-
std::vector<int64_t> sizes_int64_t_vec = {2, 3};
82-
std::vector<int32_t> sizes_in32_t_vec = {2, 3};
83-
auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
84-
85-
Tensor out = tf.zeros(sizes_in32_t_vec);
86-
87-
op_full_out(sizes, 1000, out);
88-
}
87+
GENERATE_SCALAR_OVERFLOW_TESTS(OpFullOutTest)
8988

9089
TEST_F(OpFullOutTest, HalfSupport) {
9190
TensorFactory<ScalarType::Half> tf;

0 commit comments

Comments
 (0)