|
| 1 | +//===- ArithmeticUtils.h - Arithmetic helper functions ----------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This header is not part of the public API. It is placed in the |
| 10 | +// includes directory only because that's required by the implementations |
| 11 | +// of template-classes. |
| 12 | +// |
| 13 | +// This file is part of the lightweight runtime support library for sparse |
| 14 | +// tensor manipulations. The functionality of the support library is meant |
| 15 | +// to simplify benchmarking, testing, and debugging MLIR code operating on |
| 16 | +// sparse tensors. However, the provided functionality is **not** part of |
| 17 | +// core MLIR itself. |
| 18 | +// |
| 19 | +//===----------------------------------------------------------------------===// |
| 20 | + |
| 21 | +#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H |
| 22 | +#define MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H |
| 23 | + |
| 24 | +#include <cassert> |
| 25 | +#include <cinttypes> |
| 26 | +#include <limits> |
| 27 | + |
| 28 | +namespace mlir { |
| 29 | +namespace sparse_tensor { |
| 30 | +namespace detail { |
| 31 | + |
| 32 | +//===----------------------------------------------------------------------===// |
| 33 | +// |
| 34 | +// Safe comparison functions. |
| 35 | +// |
| 36 | +// Variants of the `==`, `!=`, `<`, `<=`, `>`, and `>=` operators which |
| 37 | +// are careful to ensure that negatives are always considered strictly |
| 38 | +// less than non-negatives regardless of the signedness of the types of |
| 39 | +// the two arguments. They are "safe" in that they guarantee to *always* |
| 40 | +// give an output and that that output is correct; in particular this means |
| 41 | +// they never use assertions or other mechanisms for "returning an error". |
| 42 | +// |
| 43 | +// These functions are C++17-compatible backports of the safe comparison |
| 44 | +// functions added in C++20, and the implementations are based on the |
| 45 | +// sample implementations provided by the standard: |
| 46 | +// <https://en.cppreference.com/w/cpp/utility/intcmp>. |
| 47 | +// |
| 48 | +//===----------------------------------------------------------------------===// |
| 49 | + |
| 50 | +template <typename T, typename U> |
| 51 | +constexpr bool safelyEQ(T t, U u) noexcept { |
| 52 | + using UT = std::make_unsigned_t<T>; |
| 53 | + using UU = std::make_unsigned_t<U>; |
| 54 | + if constexpr (std::is_signed_v<T> == std::is_signed_v<U>) |
| 55 | + return t == u; |
| 56 | + else if constexpr (std::is_signed_v<T>) |
| 57 | + return t < 0 ? false : static_cast<UT>(t) == u; |
| 58 | + else |
| 59 | + return u < 0 ? false : t == static_cast<UU>(u); |
| 60 | +} |
| 61 | + |
| 62 | +template <typename T, typename U> |
| 63 | +constexpr bool safelyNE(T t, U u) noexcept { |
| 64 | + return !safelyEQ(t, u); |
| 65 | +} |
| 66 | + |
| 67 | +template <typename T, typename U> |
| 68 | +constexpr bool safelyLT(T t, U u) noexcept { |
| 69 | + using UT = std::make_unsigned_t<T>; |
| 70 | + using UU = std::make_unsigned_t<U>; |
| 71 | + if constexpr (std::is_signed_v<T> == std::is_signed_v<U>) |
| 72 | + return t < u; |
| 73 | + else if constexpr (std::is_signed_v<T>) |
| 74 | + return t < 0 ? true : static_cast<UT>(t) < u; |
| 75 | + else |
| 76 | + return u < 0 ? false : t < static_cast<UU>(u); |
| 77 | +} |
| 78 | + |
| 79 | +template <typename T, typename U> |
| 80 | +constexpr bool safelyGT(T t, U u) noexcept { |
| 81 | + return safelyLT(u, t); |
| 82 | +} |
| 83 | + |
| 84 | +template <typename T, typename U> |
| 85 | +constexpr bool safelyLE(T t, U u) noexcept { |
| 86 | + return !safelyGT(t, u); |
| 87 | +} |
| 88 | + |
| 89 | +template <typename T, typename U> |
| 90 | +constexpr bool safelyGE(T t, U u) noexcept { |
| 91 | + return !safelyLT(t, u); |
| 92 | +} |
| 93 | + |
| 94 | +//===----------------------------------------------------------------------===// |
| 95 | +// |
| 96 | +// Overflow checking functions. |
| 97 | +// |
| 98 | +// These functions use assertions to ensure correctness with respect to |
| 99 | +// overflow/underflow. Unlike the "safe" functions above, these "checked" |
| 100 | +// functions only guarantee that *if* they return an answer then that answer |
| 101 | +// is correct. When assertions are enabled, they do their best to remain |
| 102 | +// as fast as possible (since MLIR keeps assertions enabled by default, |
| 103 | +// even for optimized builds). When assertions are disabled, they use the |
| 104 | +// standard unchecked implementations. |
| 105 | +// |
| 106 | +//===----------------------------------------------------------------------===// |
| 107 | + |
| 108 | +// TODO: we would like to be able to pass in custom error messages, to |
| 109 | +// improve the user experience. We should be able to use something like |
| 110 | +// `assert(((void)(msg ? msg : defaultMsg), cond))`; but I'm not entirely |
| 111 | +// sure that'll work as intended when done within a function-definition |
| 112 | +// rather than within a macro-definition. |
| 113 | + |
| 114 | +/// A version of `static_cast<To>` which checks for overflow/underflow. |
| 115 | +/// The implementation avoids performing runtime assertions whenever |
| 116 | +/// the types alone are sufficient to statically prove that overflow |
| 117 | +/// cannot happen. |
| 118 | +template <typename To, typename From> |
| 119 | +[[nodiscard]] inline To checkOverflowCast(From x) { |
| 120 | + // Check the lower bound. (For when casting from signed types.) |
| 121 | + constexpr To minTo = std::numeric_limits<To>::min(); |
| 122 | + constexpr From minFrom = std::numeric_limits<From>::min(); |
| 123 | + if constexpr (!safelyGE(minFrom, minTo)) |
| 124 | + assert(safelyGE(x, minTo) && "cast would underflow"); |
| 125 | + // Check the upper bound. |
| 126 | + constexpr To maxTo = std::numeric_limits<To>::max(); |
| 127 | + constexpr From maxFrom = std::numeric_limits<From>::max(); |
| 128 | + if constexpr (!safelyLE(maxFrom, maxTo)) |
| 129 | + assert(safelyLE(x, maxTo) && "cast would overflow"); |
| 130 | + // Now do the cast itself. |
| 131 | + return static_cast<To>(x); |
| 132 | +} |
| 133 | + |
| 134 | +// TODO: would be better to use various architectures' intrinsics to |
| 135 | +// detect the overflow directly, instead of doing the assertion beforehand |
| 136 | +// (which requires an expensive division). |
| 137 | +// |
| 138 | +/// A version of `operator*` on `uint64_t` which guards against overflows |
| 139 | +/// (when assertions are enabled). |
| 140 | +inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) { |
| 141 | + assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) && |
| 142 | + "Integer overflow"); |
| 143 | + return lhs * rhs; |
| 144 | +} |
| 145 | + |
| 146 | +} // namespace detail |
| 147 | +} // namespace sparse_tensor |
| 148 | +} // namespace mlir |
| 149 | + |
| 150 | +#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H |
0 commit comments