Skip to content

Commit 084b064

Browse files
[ET][Portable] Add util to check for overflow when casting scalar (#12041)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #12011 by @manuelcandales ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/118/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/118/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/manuelcandales/117/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/manuelcandales/118/orig @diff-train-skip-merge --------- Co-authored-by: Manuel Candales <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent 77a7ace commit 084b064

File tree

3 files changed

+125
-0
lines changed

3 files changed

+125
-0
lines changed

kernels/portable/cpu/scalar_utils.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <cmath>
1212
#include <limits>
1313

14+
#include <c10/util/overflows.h>
1415
#include <executorch/kernels/portable/cpu/selective_build.h>
1516
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1617
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -287,6 +288,29 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
287288
: s.to<int64_t>();
288289
}
289290

291+
namespace internal {
292+
293+
template <typename To, typename From>
294+
std::optional<To> check_overflow_cast(From in) {
295+
// Converting to bool can't overflow so we exclude this case from checking.
296+
if (!std::is_same_v<To, bool> && c10::overflows<To, From>(in)) {
297+
return std::nullopt;
298+
}
299+
return static_cast<To>(in);
300+
}
301+
302+
template <typename To>
303+
std::optional<To> check_overflow_scalar_cast(const Scalar& in) {
304+
if (in.isBoolean()) {
305+
return check_overflow_cast<To>(in.to<bool>());
306+
} else if (in.isFloatingPoint()) {
307+
return check_overflow_cast<To>(in.to<double>());
308+
} else {
309+
return check_overflow_cast<To>(in.to<int64_t>());
310+
}
311+
}
312+
313+
} // namespace internal
290314
} // namespace utils
291315
} // namespace native
292316
} // namespace executor

runtime/core/portable_type/c10/c10/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def define_common_targets():
112112
"util/complex_utils.h",
113113
"util/floating_point_utils.h",
114114
"util/irange.h",
115+
"util/overflows.h",
115116
],
116117
exported_preprocessor_flags = [
117118
"-DC10_USING_CUSTOM_GENERATED_MACROS",
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#pragma once
2+
3+
#include <c10/macros/Macros.h>
4+
#include <c10/util/TypeSafeSignMath.h>
5+
#include <c10/util/complex.h>
6+
7+
#include <cmath>
8+
#include <limits>
9+
#include <type_traits>
10+
11+
namespace c10 {
12+
// In some versions of MSVC, there will be a compiler error when building.
13+
// C4146: unary minus operator applied to unsigned type, result still unsigned
14+
// C4804: unsafe use of type 'bool' in operation
15+
// It can be addressed by disabling the following warning.
16+
#ifdef _MSC_VER
17+
#pragma warning(push)
18+
#pragma warning(disable : 4146)
19+
#pragma warning(disable : 4804)
20+
#pragma warning(disable : 4018)
21+
#endif
22+
23+
// The overflow checks may involve float to int conversion which may
24+
// trigger precision loss warning. Re-enable the warning once the code
25+
// is fixed. See T58053069.
26+
C10_CLANG_DIAGNOSTIC_PUSH()
27+
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
28+
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
29+
#endif
30+
31+
// bool can be converted to any type.
32+
// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build:
33+
// `error: comparison of constant '255' with boolean expression is always false`
34+
// for `f > limit::max()` below
35+
template <typename To, typename From>
36+
std::enable_if_t<std::is_same_v<From, bool>, bool> overflows(
37+
From /*f*/,
38+
bool strict_unsigned [[maybe_unused]] = false) {
39+
return false;
40+
}
41+
42+
// skip isnan and isinf check for integral types
43+
template <typename To, typename From>
44+
std::enable_if_t<std::is_integral_v<From> && !std::is_same_v<From, bool>, bool>
45+
overflows(From f, bool strict_unsigned = false) {
46+
using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
47+
if constexpr (!limit::is_signed && std::numeric_limits<From>::is_signed) {
48+
// allow for negative numbers to wrap using two's complement arithmetic.
49+
// For example, with uint8, this allows for `a - b` to be treated as
50+
// `a + 255 * b`.
51+
if (!strict_unsigned) {
52+
return greater_than_max<To>(f) ||
53+
(c10::is_negative(f) &&
54+
-static_cast<uint64_t>(f) > static_cast<uint64_t>(limit::max()));
55+
}
56+
}
57+
return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
58+
}
59+
60+
template <typename To, typename From>
61+
std::enable_if_t<std::is_floating_point_v<From>, bool> overflows(
62+
From f,
63+
bool strict_unsigned [[maybe_unused]] = false) {
64+
using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
65+
if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
66+
return false;
67+
}
68+
if (!limit::has_quiet_NaN && (f != f)) {
69+
return true;
70+
}
71+
return f < limit::lowest() || f > limit::max();
72+
}
73+
74+
C10_CLANG_DIAGNOSTIC_POP()
75+
76+
#ifdef _MSC_VER
77+
#pragma warning(pop)
78+
#endif
79+
80+
template <typename To, typename From>
81+
std::enable_if_t<is_complex<From>::value, bool> overflows(
82+
From f,
83+
bool strict_unsigned = false) {
84+
// casts from complex to real are considered to overflow if the
85+
// imaginary component is non-zero
86+
if (!is_complex<To>::value && f.imag() != 0) {
87+
return true;
88+
}
89+
// Check for overflow componentwise
90+
// (Technically, the imag overflow check is guaranteed to be false
91+
// when !is_complex<To>, but any optimizer worth its salt will be
92+
// able to figure it out.)
93+
return overflows<
94+
typename scalar_value_type<To>::type,
95+
typename From::value_type>(f.real(), strict_unsigned) ||
96+
overflows<
97+
typename scalar_value_type<To>::type,
98+
typename From::value_type>(f.imag(), strict_unsigned);
99+
}
100+
} // namespace c10

0 commit comments

Comments
 (0)