Skip to content

Commit 9d86299

Browse files
committed
[mlir][sparse] Adding safe comparison functions to MLIRSparseTensorRuntime.
Different platforms use different signedness for `StridedMemRefType::sizes` and `std::vector::size_type`, and this has been causing a lot of portability issues re [-Wsign-compare] warnings. These new functions ensure that we need never worry about those signedness warnings ever again. Also merging CheckedMul.h into ArithmeticUtils.h Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D138149
1 parent d62d278 commit 9d86299

File tree

4 files changed

+162
-59
lines changed

4 files changed

+162
-59
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
//===- ArithmeticUtils.h - Arithmetic helper functions ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header is not part of the public API. It is placed in the
10+
// includes directory only because that's required by the implementations
11+
// of template-classes.
12+
//
13+
// This file is part of the lightweight runtime support library for sparse
14+
// tensor manipulations. The functionality of the support library is meant
15+
// to simplify benchmarking, testing, and debugging MLIR code operating on
16+
// sparse tensors. However, the provided functionality is **not** part of
17+
// core MLIR itself.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
21+
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
22+
#define MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H
23+
24+
#include <cassert>
25+
#include <cinttypes>
26+
#include <limits>
27+
28+
namespace mlir {
29+
namespace sparse_tensor {
30+
namespace detail {
31+
32+
//===----------------------------------------------------------------------===//
33+
//
34+
// Safe comparison functions.
35+
//
36+
// Variants of the `==`, `!=`, `<`, `<=`, `>`, and `>=` operators which
37+
// are careful to ensure that negatives are always considered strictly
38+
// less than non-negatives regardless of the signedness of the types of
39+
// the two arguments. They are "safe" in that they guarantee to *always*
40+
// give an output and that that output is correct; in particular this means
41+
// they never use assertions or other mechanisms for "returning an error".
42+
//
43+
// These functions are C++17-compatible backports of the safe comparison
44+
// functions added in C++20, and the implementations are based on the
45+
// sample implementations provided by the standard:
46+
// <https://en.cppreference.com/w/cpp/utility/intcmp>.
47+
//
48+
//===----------------------------------------------------------------------===//
49+
50+
template <typename T, typename U>
51+
constexpr bool safelyEQ(T t, U u) noexcept {
52+
using UT = std::make_unsigned_t<T>;
53+
using UU = std::make_unsigned_t<U>;
54+
if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
55+
return t == u;
56+
else if constexpr (std::is_signed_v<T>)
57+
return t < 0 ? false : static_cast<UT>(t) == u;
58+
else
59+
return u < 0 ? false : t == static_cast<UU>(u);
60+
}
61+
62+
template <typename T, typename U>
63+
constexpr bool safelyNE(T t, U u) noexcept {
64+
return !safelyEQ(t, u);
65+
}
66+
67+
template <typename T, typename U>
68+
constexpr bool safelyLT(T t, U u) noexcept {
69+
using UT = std::make_unsigned_t<T>;
70+
using UU = std::make_unsigned_t<U>;
71+
if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
72+
return t < u;
73+
else if constexpr (std::is_signed_v<T>)
74+
return t < 0 ? true : static_cast<UT>(t) < u;
75+
else
76+
return u < 0 ? false : t < static_cast<UU>(u);
77+
}
78+
79+
template <typename T, typename U>
80+
constexpr bool safelyGT(T t, U u) noexcept {
81+
return safelyLT(u, t);
82+
}
83+
84+
template <typename T, typename U>
85+
constexpr bool safelyLE(T t, U u) noexcept {
86+
return !safelyGT(t, u);
87+
}
88+
89+
template <typename T, typename U>
90+
constexpr bool safelyGE(T t, U u) noexcept {
91+
return !safelyLT(t, u);
92+
}
93+
94+
//===----------------------------------------------------------------------===//
95+
//
96+
// Overflow checking functions.
97+
//
98+
// These functions use assertions to ensure correctness with respect to
99+
// overflow/underflow. Unlike the "safe" functions above, these "checked"
100+
// functions only guarantee that *if* they return an answer then that answer
101+
// is correct. When assertions are enabled, they do their best to remain
102+
// as fast as possible (since MLIR keeps assertions enabled by default,
103+
// even for optimized builds). When assertions are disabled, they use the
104+
// standard unchecked implementations.
105+
//
106+
//===----------------------------------------------------------------------===//
107+
108+
// TODO: we would like to be able to pass in custom error messages, to
109+
// improve the user experience. We should be able to use something like
110+
// `assert(((void)(msg ? msg : defaultMsg), cond))`; but I'm not entirely
111+
// sure that'll work as intended when done within a function-definition
112+
// rather than within a macro-definition.
113+
114+
/// A version of `static_cast<To>` which checks for overflow/underflow.
115+
/// The implementation avoids performing runtime assertions whenever
116+
/// the types alone are sufficient to statically prove that overflow
117+
/// cannot happen.
118+
template <typename To, typename From>
119+
[[nodiscard]] inline To checkOverflowCast(From x) {
120+
// Check the lower bound. (For when casting from signed types.)
121+
constexpr To minTo = std::numeric_limits<To>::min();
122+
constexpr From minFrom = std::numeric_limits<From>::min();
123+
if constexpr (!safelyGE(minFrom, minTo))
124+
assert(safelyGE(x, minTo) && "cast would underflow");
125+
// Check the upper bound.
126+
constexpr To maxTo = std::numeric_limits<To>::max();
127+
constexpr From maxFrom = std::numeric_limits<From>::max();
128+
if constexpr (!safelyLE(maxFrom, maxTo))
129+
assert(safelyLE(x, maxTo) && "cast would overflow");
130+
// Now do the cast itself.
131+
return static_cast<To>(x);
132+
}
133+
134+
// TODO: would be better to use various architectures' intrinsics to
135+
// detect the overflow directly, instead of doing the assertion beforehand
136+
// (which requires an expensive division).
137+
//
138+
/// A version of `operator*` on `uint64_t` which guards against overflows
139+
/// (when assertions are enabled).
140+
inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
141+
assert((lhs == 0 || rhs <= std::numeric_limits<uint64_t>::max() / lhs) &&
142+
"Integer overflow");
143+
return lhs * rhs;
144+
}
145+
146+
} // namespace detail
147+
} // namespace sparse_tensor
148+
} // namespace mlir
149+
150+
#endif // MLIR_EXECUTIONENGINE_SPARSETENSOR_ARITHMETICUTILS_H

mlir/include/mlir/ExecutionEngine/SparseTensor/CheckedMul.h

Lines changed: 0 additions & 48 deletions
This file was deleted.

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535

3636
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
3737
#include "mlir/ExecutionEngine/Float16bits.h"
38+
#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
3839
#include "mlir/ExecutionEngine/SparseTensor/Attributes.h"
3940
#include "mlir/ExecutionEngine/SparseTensor/COO.h"
40-
#include "mlir/ExecutionEngine/SparseTensor/CheckedMul.h"
4141
#include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
4242

4343
namespace mlir {
@@ -509,9 +509,10 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
509509
/// the previous position and smaller than `indices[l].capacity()`).
510510
void appendPointer(uint64_t l, uint64_t pos, uint64_t count = 1) {
511511
ASSERT_COMPRESSED_LVL(l);
512-
assert(pos <= std::numeric_limits<P>::max() &&
513-
"Pointer value is too large for the P-type");
514-
pointers[l].insert(pointers[l].end(), count, static_cast<P>(pos));
512+
// TODO: we'd like to recover the nicer error message:
513+
// "Pointer value is too large for the P-type"
514+
pointers[l].insert(pointers[l].end(), count,
515+
detail::checkOverflowCast<P>(pos));
515516
}
516517

517518
/// Appends index `i` to level `l`, in the semantically general sense.
@@ -526,9 +527,9 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
526527
void appendIndex(uint64_t l, uint64_t full, uint64_t i) {
527528
const auto dlt = getLvlType(l); // Avoid redundant bounds checking.
528529
if (isCompressedDLT(dlt) || isSingletonDLT(dlt)) {
529-
assert(i <= std::numeric_limits<I>::max() &&
530-
"Index value is too large for the I-type");
531-
indices[l].push_back(static_cast<I>(i));
530+
// TODO: we'd like to recover the nicer error message:
531+
// "Index value is too large for the I-type"
532+
indices[l].push_back(detail::checkOverflowCast<I>(i));
532533
} else { // Dense dimension.
533534
ASSERT_DENSE_DLT(dlt);
534535
assert(i >= full && "Index was already filled");
@@ -551,9 +552,9 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
551552
// entry has been initialized; thus we must be sure to check `size()`
552553
// here, instead of `capacity()` as would be ideal.
553554
assert(pos < indices[l].size() && "Index position is out of bounds");
554-
assert(i <= std::numeric_limits<I>::max() &&
555-
"Index value is too large for the I-type");
556-
indices[l][pos] = static_cast<I>(i);
555+
// TODO: we'd like to recover the nicer error message:
556+
// "Index value is too large for the I-type"
557+
indices[l][pos] = detail::checkOverflowCast<I>(i);
557558
}
558559

559560
/// Computes the assembled-size associated with the `l`-th level,

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6900,9 +6900,9 @@ cc_library(
69006900
"lib/ExecutionEngine/SparseTensor/Storage.cpp",
69016901
],
69026902
hdrs = [
6903+
"include/mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h",
69036904
"include/mlir/ExecutionEngine/SparseTensor/Attributes.h",
69046905
"include/mlir/ExecutionEngine/SparseTensor/COO.h",
6905-
"include/mlir/ExecutionEngine/SparseTensor/CheckedMul.h",
69066906
"include/mlir/ExecutionEngine/SparseTensor/ErrorHandling.h",
69076907
"include/mlir/ExecutionEngine/SparseTensor/File.h",
69086908
"include/mlir/ExecutionEngine/SparseTensor/PermutationRef.h",

0 commit comments

Comments
 (0)