Skip to content

Commit 5e01741

Browse files
committed
[mlir][presburger] Optimize the compilation time for calculating bounds of an Integer Relation
IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to simplify constraints. These methods may repeatedly perform calculations and elimination on irrelevant variables. Preemptively eliminating irrelevant variables and their associated constraints can speed up up the calculation process.
1 parent 45495b5 commit 5e01741

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,31 @@ class IntegerRelation {
511511
void projectOut(unsigned pos, unsigned num);
512512
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
513513

514+
/// The set of constraints (equations/inequalities) can be modeled as an
515+
/// undirected graph where:
516+
/// 1. Variables are the nodes.
517+
/// 2. Constraints are the edges connecting those nodes.
518+
///
519+
/// Variables and constraints belonging to different connected components
520+
/// are irrelevant to each other. This property allows for safe pruning of
521+
/// constraints.
522+
///
523+
/// For example, given the following constraints:
524+
/// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5
525+
/// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3
526+
///
527+
/// These form two connected components:
528+
/// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5)
529+
/// - Component 2: {d3, d4} (related by constraint 4)
530+
///
531+
/// If we are querying the bound of variable `d0`, constraints related to
532+
/// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they
533+
/// have no impact on the solution space of Component 1.
534+
/// This function prunes irrelevant constraints by identifying all variables
535+
/// and constraints that belong to the same connected component as the
536+
/// target variable.
537+
void pruneConstraints(unsigned pos);
538+
514539
/// Tries to fold the specified variable to a constant using a trivial
515540
/// equality detection; if successful, the constant is substituted for the
516541
/// variable everywhere in the constraint system and then removed from the

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Analysis/Presburger/Simplex.h"
2222
#include "mlir/Analysis/Presburger/Utils.h"
2323
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/DenseSet.h"
2425
#include "llvm/ADT/STLExtras.h"
2526
#include "llvm/ADT/Sequence.h"
2627
#include "llvm/ADT/SmallBitVector.h"
@@ -1723,12 +1724,78 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
17231724
return minDiff;
17241725
}
17251726

1727+
void IntegerRelation::pruneConstraints(unsigned pos) {
1728+
llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
1729+
1730+
// Early quit if constraints is empty.
1731+
unsigned numConstraints = getNumConstraints();
1732+
if (numConstraints == 0)
1733+
return;
1734+
1735+
llvm::SmallVector<unsigned> rowStack, colStack({pos});
1736+
// The following code performs a graph traversal, starting from the target
1737+
// variable, to identify all variables(recorded in relatedCols) and
1738+
// constraints(recorded in relatedRows) belonging to the same connected
1739+
// component.
1740+
while (!rowStack.empty() || !colStack.empty()) {
1741+
if (!rowStack.empty()) {
1742+
unsigned currentRow = rowStack.pop_back_val();
1743+
// Push all variable that accociated to this constrain to relatedCols
1744+
// and colStack.
1745+
for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) {
1746+
if (currentRow < getNumInequalities()) {
1747+
if (atIneq(currentRow, colIndex) != 0 &&
1748+
relatedCols.insert(colIndex).second) {
1749+
colStack.push_back(colIndex);
1750+
}
1751+
} else {
1752+
if (atEq(currentRow - getNumInequalities(), colIndex) != 0 &&
1753+
relatedCols.insert(colIndex).second) {
1754+
colStack.push_back(colIndex);
1755+
}
1756+
}
1757+
}
1758+
} else {
1759+
unsigned currentCol = colStack.pop_back_val();
1760+
// Push all constrains that accociated to this variable to relatedRows
1761+
// and rowStack.
1762+
for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
1763+
if (rowIndex < getNumInequalities()) {
1764+
if (atIneq(rowIndex, currentCol) != 0 &&
1765+
relatedRows.insert(rowIndex).second) {
1766+
rowStack.push_back(rowIndex);
1767+
}
1768+
} else {
1769+
if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 &&
1770+
relatedRows.insert(rowIndex).second) {
1771+
rowStack.push_back(rowIndex);
1772+
}
1773+
}
1774+
}
1775+
}
1776+
}
1777+
1778+
// Prune all constraints not related to target variable.
1779+
for (int64_t constraintId = numConstraints - 1; constraintId >= 0;
1780+
--constraintId) {
1781+
if (!relatedRows.contains(constraintId)) {
1782+
if (constraintId >= getNumInequalities()) {
1783+
removeEquality(constraintId - getNumInequalities());
1784+
} else {
1785+
removeInequality(constraintId);
1786+
}
1787+
}
1788+
}
1789+
}
1790+
17261791
template <bool isLower>
17271792
std::optional<DynamicAPInt>
17281793
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
17291794
assert(pos < getNumVars() && "invalid position");
17301795
// Project to 'pos'.
1796+
pruneConstraints(pos);
17311797
projectOut(0, pos);
1798+
pruneConstraints(0);
17321799
projectOut(1, getNumVars() - 1);
17331800
// Check if there's an equality equating the '0'^th variable to a constant.
17341801
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);

0 commit comments

Comments
 (0)