Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions kernels/portable/cpu/scalar_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <cmath>
#include <limits>

#include <c10/util/overflows.h>
#include <executorch/kernels/portable/cpu/selective_build.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
Expand Down Expand Up @@ -287,6 +288,29 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
: s.to<int64_t>();
}

namespace internal {

template <typename To, typename From>
std::optional<To> check_overflow_cast(From in) {
// Converting to bool can't overflow so we exclude this case from checking.
if (!std::is_same_v<To, bool> && c10::overflows<To, From>(in)) {
return std::nullopt;
}
return static_cast<To>(in);
}

template <typename To>
std::optional<To> check_overflow_scalar_cast(const Scalar& in) {
if (in.isBoolean()) {
return check_overflow_cast<To>(in.to<bool>());
} else if (in.isFloatingPoint()) {
return check_overflow_cast<To>(in.to<double>());
} else {
return check_overflow_cast<To>(in.to<int64_t>());
}
}

} // namespace internal
} // namespace utils
} // namespace native
} // namespace executor
Expand Down
Loading