Skip to content

Commit 7d92e60

Browse files
[ET][Portable] Add util to check for overflow when casting scalar
Differential Revision: [D77382647](https://our.internmc.facebook.com/intern/diff/D77382647/) [ghstack-poisoned]
1 parent 8a934d5 commit 7d92e60

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

kernels/portable/cpu/scalar_utils.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <cmath>
1212
#include <limits>
1313

14+
#include <c10/util/overflows.h>
1415
#include <executorch/kernels/portable/cpu/selective_build.h>
1516
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1617
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -287,6 +288,29 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
287288
: s.to<int64_t>();
288289
}
289290

291+
namespace internal {
292+
293+
template <typename To, typename From>
294+
std::optional<To> check_overflow_cast(From in) {
295+
// Converting to bool can't overflow so we exclude this case from checking.
296+
if (!std::is_same_v<To, bool> && c10::overflows<To, From>(in)) {
297+
return std::nullopt;
298+
}
299+
return static_cast<To>(in);
300+
}
301+
302+
template <typename To>
303+
std::optional<To> check_overflow_scalar_cast(const Scalar& in) {
304+
if (in.isBoolean()) {
305+
return check_overflow_cast<To>(in.to<bool>());
306+
} else if (in.isFloatingPoint()) {
307+
return check_overflow_cast<To>(in.to<double>());
308+
} else {
309+
return check_overflow_cast<To>(in.to<int64_t>());
310+
}
311+
}
312+
313+
} // namespace internal
290314
} // namespace utils
291315
} // namespace native
292316
} // namespace executor

0 commit comments

Comments
 (0)