diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index f86535740fec9..026d84529edfb 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -511,6 +511,9 @@ class IntegerRelation { void projectOut(unsigned pos, unsigned num); inline void projectOut(unsigned pos) { return projectOut(pos, 1); } + /// Prune constraints that are irrelevant to the target variable. + void pruneConstraints(unsigned pos); + /// Tries to fold the specified variable to a constant using a trivial /// equality detection; if successful, the constant is substituted for the /// variable everywhere in the constraint system and then removed from the diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 0dcdd5bb97bc8..1352c1b2da40c 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1723,12 +1723,65 @@ std::optional IntegerRelation::getConstantBoundOnDimSize( return minDiff; } +void IntegerRelation::pruneConstraints(unsigned pos) { + llvm::DenseSet relatedCols({pos}), relatedRows; + + llvm::SmallVector rowStack, colStack({pos}); + unsigned numConstraints = getNumConstraints(); + while (!rowStack.empty() || !colStack.empty()) { + if (!rowStack.empty()) { + unsigned currentRow = rowStack.pop_back_val(); + for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) { + if (currentRow < getNumInequalities()) { + if (atIneq(currentRow, colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } else { + if (atEq(currentRow - getNumInequalities(), colIndex) != 0 && + relatedCols.insert(colIndex).second) { + colStack.push_back(colIndex); + } + } + } + } else { + unsigned currentCol = colStack.pop_back_val(); + for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { + if (rowIndex < getNumInequalities()) { + if (atIneq(rowIndex, currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } else { + if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 && + relatedRows.insert(rowIndex).second) { + rowStack.push_back(rowIndex); + } + } + } + } + } + + for (int64_t constraintId = numConstraints - 1; constraintId >= 0; + --constraintId) { + if (!relatedRows.contains(constraintId)) { + if (constraintId >= getNumInequalities()) { + removeEquality(constraintId - getNumInequalities()); + } else { + removeInequality(constraintId); + } + } + } +} + template std::optional IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { assert(pos < getNumVars() && "invalid position"); // Project to 'pos'. + pruneConstraints(pos); projectOut(0, pos); + pruneConstraints(0); projectOut(1, getNumVars() - 1); // Check if there's an equality equating the '0'^th variable to a constant. int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);