Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions kernels/portable/cpu/scalar_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cmath>
#include <limits>

#include <c10/util/overflows.h>
#include <executorch/kernels/portable/cpu/selective_build.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
Expand Down Expand Up @@ -287,6 +288,29 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
: s.to<int64_t>();
}

namespace internal {

template <typename To, typename From>
std::optional<To> check_overflow_cast(From in) {
// Converting to bool can't overflow so we exclude this case from checking.
if (!std::is_same_v<To, bool> && c10::overflows<To, From>(in)) {
return std::nullopt;
}
return static_cast<To>(in);
}

template <typename To>
std::optional<To> check_overflow_scalar_cast(const Scalar& in) {
if (in.isBoolean()) {
return check_overflow_cast<To>(in.to<bool>());
} else if (in.isFloatingPoint()) {
return check_overflow_cast<To>(in.to<double>());
} else {
return check_overflow_cast<To>(in.to<int64_t>());
}
}

} // namespace internal
} // namespace utils
} // namespace native
} // namespace executor
Expand Down
1 change: 1 addition & 0 deletions runtime/core/portable_type/c10/c10/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def define_common_targets():
"util/complex_utils.h",
"util/floating_point_utils.h",
"util/irange.h",
"util/overflows.h",
],
exported_preprocessor_flags = [
"-DC10_USING_CUSTOM_GENERATED_MACROS",
Expand Down
100 changes: 100 additions & 0 deletions runtime/core/portable_type/c10/c10/util/overflows.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/TypeSafeSignMath.h>
#include <c10/util/complex.h>

#include <cmath>
#include <limits>
#include <type_traits>

namespace c10 {
// In some versions of MSVC, there will be a compiler error when building.
// C4146: unary minus operator applied to unsigned type, result still unsigned
// C4804: unsafe use of type 'bool' in operation
// It can be addressed by disabling the following warning.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4146)
#pragma warning(disable : 4804)
#pragma warning(disable : 4018)
#endif

// The overflow checks may involve float to int conversion which may
// trigger precision loss warning. Re-enable the warning once the code
// is fixed. See T58053069.
C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
#endif

// bool can be converted to any type.
// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build:
// `error: comparison of constant '255' with boolean expression is always false`
// for `f > limit::max()` below
template <typename To, typename From>
std::enable_if_t<std::is_same_v<From, bool>, bool> overflows(
From /*f*/,
bool strict_unsigned [[maybe_unused]] = false) {
return false;
}

// skip isnan and isinf check for integral types
template <typename To, typename From>
std::enable_if_t<std::is_integral_v<From> && !std::is_same_v<From, bool>, bool>
overflows(From f, bool strict_unsigned = false) {
using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
if constexpr (!limit::is_signed && std::numeric_limits<From>::is_signed) {
// allow for negative numbers to wrap using two's complement arithmetic.
// For example, with uint8, this allows for `a - b` to be treated as
// `a + 255 * b`.
if (!strict_unsigned) {
return greater_than_max<To>(f) ||
(c10::is_negative(f) &&
-static_cast<uint64_t>(f) > static_cast<uint64_t>(limit::max()));
}
}
return c10::less_than_lowest<To>(f) || greater_than_max<To>(f);
}

template <typename To, typename From>
std::enable_if_t<std::is_floating_point_v<From>, bool> overflows(
From f,
bool strict_unsigned [[maybe_unused]] = false) {
using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
return false;
}
if (!limit::has_quiet_NaN && (f != f)) {
return true;
}
return f < limit::lowest() || f > limit::max();
}

C10_CLANG_DIAGNOSTIC_POP()

#ifdef _MSC_VER
#pragma warning(pop)
#endif

template <typename To, typename From>
std::enable_if_t<is_complex<From>::value, bool> overflows(
From f,
bool strict_unsigned = false) {
// casts from complex to real are considered to overflow if the
// imaginary component is non-zero
if (!is_complex<To>::value && f.imag() != 0) {
return true;
}
// Check for overflow componentwise
// (Technically, the imag overflow check is guaranteed to be false
// when !is_complex<To>, but any optimizer worth its salt will be
// able to figure it out.)
return overflows<
typename scalar_value_type<To>::type,
typename From::value_type>(f.real(), strict_unsigned) ||
overflows<
typename scalar_value_type<To>::type,
typename From::value_type>(f.imag(), strict_unsigned);
}
} // namespace c10
Loading