From 4deef7b149eccc879902633226dd88f0c9859040 Mon Sep 17 00:00:00 2001 From: donald chen Date: Tue, 14 Oct 2025 13:50:31 +0000 Subject: [PATCH] [mlir][presburger] Optimize the compilation time for calculating bounds of an Integer Relation IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to simplify constraints. These methods may repeatedly perform calculations and elimination on irrelevant variables. Preemptively eliminating irrelevant variables and their associated constraints can speed up up the calculation process. --- .../Analysis/Presburger/IntegerRelation.h | 3 ++ .../Analysis/Presburger/IntegerRelation.cpp | 53 +++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index f86535740fec9..026d84529edfb 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -511,6 +511,9 @@ class IntegerRelation { void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + /// Prune constraints that are irrelevant to the target variable. + void pruneConstraints(unsigned pos); + /// Tries to fold the specified variable to a constant using a trivial /// equality detection; if successful, the constant is substituted for the /// variable everywhere in the constraint system and then removed from the diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 0dcdd5bb97bc8..1352c1b2da40c 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1723,12 +1723,65 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( return minDiff; } +void IntegerRelation::pruneConstraints(unsigned pos) { + llvm::DenseSet relatedCols({pos}), relatedRows; + + llvm::SmallVector rowStack, colStack({pos}); + unsigned numConstraints = getNumConstraints(); + while (!rowStack.empty() || !colStack.empty()) { + if (!rowStack.empty()) { + unsigned currentRow = rowStack.pop_back_val(); + for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) { + if (currentRow < getNumInequalities()) { + if (atIneq(currentRow, colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } else { + if (atEq(currentRow - getNumInequalities(), colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } + } + } else { + unsigned currentCol = colStack.pop_back_val(); + for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { + if (rowIndex < getNumInequalities()) { + if (atIneq(rowIndex, currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } else { + if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } + } + } + } + + for (int64_t constraintId = numConstraints - 1; constraintId >= 0; + --constraintId) { + if (!relatedRows.contains(constraintId)) { + if (constraintId >= getNumInequalities()) { + removeEquality(constraintId - getNumInequalities()); + } else { + removeInequality(constraintId); + } + } + } +} + template std::optional IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. + pruneConstraints(pos); projectOut(0, pos); + pruneConstraints(0); projectOut(1, getNumVars() - 1); // Check if there's an equality equating the '0'^th variable to a constant. int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);