Skip to content

Commit ee2806f

Browse files
committed
[mlir][presburger] Optimize the compilation time for calculating bounds of an Integer Relation
IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to simplify constraints. These methods may repeatedly perform calculations and elimination on irrelevant variables. Preemptively eliminating irrelevant variables and their associated constraints can speed up up the calculation process.
1 parent 45495b5 commit ee2806f

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,9 @@ class IntegerRelation {
511511
void projectOut(unsigned pos, unsigned num);
512512
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
513513

514+
/// Prune constraints that are irrelevant to the target variable.
515+
void pruneConstraints(unsigned pos);
516+
514517
/// Tries to fold the specified variable to a constant using a trivial
515518
/// equality detection; if successful, the constant is substituted for the
516519
/// variable everywhere in the constraint system and then removed from the

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Analysis/Presburger/Simplex.h"
2222
#include "mlir/Analysis/Presburger/Utils.h"
2323
#include "llvm/ADT/DenseMap.h"
24+
#include "llvm/ADT/DenseSet.h"
2425
#include "llvm/ADT/STLExtras.h"
2526
#include "llvm/ADT/Sequence.h"
2627
#include "llvm/ADT/SmallBitVector.h"
@@ -1723,12 +1724,67 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
17231724
return minDiff;
17241725
}
17251726

1727+
void IntegerRelation::pruneConstraints(unsigned pos) {
1728+
llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
1729+
1730+
llvm::SmallVector<unsigned> rowStack, colStack({pos});
1731+
unsigned numConstraints = getNumConstraints();
1732+
if (numConstraints == 0)
1733+
return;
1734+
while (!rowStack.empty() || !colStack.empty()) {
1735+
if (!rowStack.empty()) {
1736+
unsigned currentRow = rowStack.pop_back_val();
1737+
for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) {
1738+
if (currentRow < getNumInequalities()) {
1739+
if (atIneq(currentRow, colIndex) != 0 &&
1740+
relatedCols.insert(colIndex).second) {
1741+
colStack.push_back(colIndex);
1742+
}
1743+
} else {
1744+
if (atEq(currentRow - getNumInequalities(), colIndex) != 0 &&
1745+
relatedCols.insert(colIndex).second) {
1746+
colStack.push_back(colIndex);
1747+
}
1748+
}
1749+
}
1750+
} else {
1751+
unsigned currentCol = colStack.pop_back_val();
1752+
for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
1753+
if (rowIndex < getNumInequalities()) {
1754+
if (atIneq(rowIndex, currentCol) != 0 &&
1755+
relatedRows.insert(rowIndex).second) {
1756+
rowStack.push_back(rowIndex);
1757+
}
1758+
} else {
1759+
if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 &&
1760+
relatedRows.insert(rowIndex).second) {
1761+
rowStack.push_back(rowIndex);
1762+
}
1763+
}
1764+
}
1765+
}
1766+
}
1767+
1768+
for (int64_t constraintId = numConstraints - 1; constraintId >= 0;
1769+
--constraintId) {
1770+
if (!relatedRows.contains(constraintId)) {
1771+
if (constraintId >= getNumInequalities()) {
1772+
removeEquality(constraintId - getNumInequalities());
1773+
} else {
1774+
removeInequality(constraintId);
1775+
}
1776+
}
1777+
}
1778+
}
1779+
17261780
template <bool isLower>
17271781
std::optional<DynamicAPInt>
17281782
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
17291783
assert(pos < getNumVars() && "invalid position");
17301784
// Project to 'pos'.
1785+
pruneConstraints(pos);
17311786
projectOut(0, pos);
1787+
pruneConstraints(0);
17321788
projectOut(1, getNumVars() - 1);
17331789
// Check if there's an equality equating the '0'^th variable to a constant.
17341790
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);

0 commit comments

Comments
 (0)