diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index f86535740fec9..60bfdfa322120 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -511,6 +511,31 @@ class IntegerRelation { void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + /// The set of constraints (equations/inequalities) can be modeled as an + /// undirected graph where: + /// 1. Variables are the nodes. + /// 2. Constraints are the edges connecting those nodes. + /// + /// Variables and constraints belonging to different connected components + /// are irrelevant to each other. This property allows for safe pruning of + /// constraints. + /// + /// For example, given the following constraints: + /// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5 + /// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3 + /// + /// These form two connected components: + /// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5) + /// - Component 2: {d3, d4} (related by constraint 4) + /// + /// If we are querying the bound of variable `d0`, constraints related to + /// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they + /// have no impact on the solution space of Component 1. + /// This function prunes irrelevant constraints by identifying all variables + /// and constraints that belong to the same connected component as the + /// target variable. + void pruneConstraints(unsigned pos); + /// Tries to fold the specified variable to a constant using a trivial /// equality detection; if successful, the constant is substituted for the /// variable everywhere in the constraint system and then removed from the diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 0dcdd5bb97bc8..0354129ddf845 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -21,6 +21,7 @@ #include "mlir/Analysis/Presburger/Simplex.h" #include "mlir/Analysis/Presburger/Utils.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallBitVector.h" @@ -1723,12 +1724,78 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( return minDiff; } +void IntegerRelation::pruneConstraints(unsigned pos) { + llvm::DenseSet relatedCols({pos}), relatedRows; + + // Early quit if constraints is empty. + unsigned numConstraints = getNumConstraints(); + if (numConstraints == 0) + return; + + llvm::SmallVector rowStack, colStack({pos}); + // The following code performs a graph traversal, starting from the target + // variable, to identify all variables(recorded in relatedCols) and + // constraints(recorded in relatedRows) belonging to the same connected + // component. + while (!rowStack.empty() || !colStack.empty()) { + if (!rowStack.empty()) { + unsigned currentRow = rowStack.pop_back_val(); + // Push all variable that accociated to this constrain to relatedCols + // and colStack. + for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) { + if (currentRow < getNumInequalities()) { + if (atIneq(currentRow, colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } else { + if (atEq(currentRow - getNumInequalities(), colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } + } + } else { + unsigned currentCol = colStack.pop_back_val(); + // Push all constrains that accociated to this variable to relatedRows + // and rowStack. + for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { + if (rowIndex < getNumInequalities()) { + if (atIneq(rowIndex, currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } else { + if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } + } + } + } + + // Prune all constraints not related to target variable. + for (int64_t constraintId = numConstraints - 1; constraintId >= 0; + --constraintId) { + if (!relatedRows.contains(constraintId)) { + if (constraintId >= getNumInequalities()) { + removeEquality(constraintId - getNumInequalities()); + } else { + removeInequality(constraintId); + } + } + } +} + template std::optional IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. + pruneConstraints(pos); projectOut(0, pos); + pruneConstraints(0); projectOut(1, getNumVars() - 1); // Check if there's an equality equating the '0'^th variable to a constant. int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);