Skip to content

Commit 4b8a584

Browse files
[ET][Portable] Check scalar overflow: op_scalar_tensor (#12082)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12025 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/123/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/123/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/122/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/123/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]>
1 parent 5cbd790 commit 4b8a584

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

kernels/portable/cpu/op_scalar_tensor.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,11 @@ scalar_tensor_out(KernelRuntimeContext& ctx, const Scalar& s, Tensor& out) {
2424

2525
constexpr auto name = "scalar_tensor.out";
2626

27-
if (s.isFloatingPoint() &&
28-
executorch::runtime::isIntegralType(out_type, false)) {
29-
ET_SWITCH_INT_TYPES(out_type, ctx, name, CTYPE, [&]() {
30-
out.mutable_data_ptr<CTYPE>()[0] =
31-
static_cast<CTYPE>(utils::scalar_to<int64_t>(s));
32-
});
33-
} else {
34-
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() {
35-
out.mutable_data_ptr<CTYPE>()[0] = utils::scalar_to<CTYPE>(s);
36-
});
37-
}
27+
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE, [&]() {
28+
auto opt_val_casted = utils::internal::check_overflow_scalar_cast<CTYPE>(s);
29+
ET_KERNEL_CHECK(ctx, opt_val_casted.has_value(), InvalidArgument, );
30+
out.mutable_data_ptr<CTYPE>()[0] = opt_val_casted.value();
31+
});
3832

3933
return out;
4034
}

kernels/test/op_scalar_tensor_test.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10+
#include <executorch/kernels/test/ScalarOverflowTestMacros.h>
1011
#include <executorch/kernels/test/TestUtil.h>
1112
#include <executorch/kernels/test/supported_features.h>
1213
#include <executorch/runtime/core/exec_aten/exec_aten.h>
@@ -71,6 +72,14 @@ class OpScalarTensorOutTest : public OperatorTest {
7172

7273
ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(value, out));
7374
}
75+
76+
template <ScalarType DTYPE>
77+
void expect_bad_scalar_value_dies(const Scalar& bad_value) {
78+
TensorFactory<DTYPE> tf;
79+
Tensor out = tf.zeros({});
80+
81+
ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(bad_value, out));
82+
}
7483
};
7584

7685
#define GENERATE_TEST_0D(ctype, dtype) \
@@ -131,3 +140,5 @@ TEST_F(OpScalarTensorOutTest, HalfSupport) {
131140
op_scalar_tensor_out(INFINITY, out);
132141
EXPECT_TENSOR_CLOSE(out, tf.make({}, {INFINITY}));
133142
}
143+
144+
GENERATE_SCALAR_OVERFLOW_TESTS(OpScalarTensorOutTest)

0 commit comments

Comments
 (0)