|
21 | 21 | #include "mlir/Analysis/Presburger/Simplex.h" |
22 | 22 | #include "mlir/Analysis/Presburger/Utils.h" |
23 | 23 | #include "llvm/ADT/DenseMap.h" |
| 24 | +#include "llvm/ADT/DenseSet.h" |
24 | 25 | #include "llvm/ADT/STLExtras.h" |
25 | 26 | #include "llvm/ADT/Sequence.h" |
26 | 27 | #include "llvm/ADT/SmallBitVector.h" |
@@ -1723,12 +1724,78 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize( |
1723 | 1724 | return minDiff; |
1724 | 1725 | } |
1725 | 1726 |
|
| 1727 | +void IntegerRelation::pruneConstraints(unsigned pos) { |
| 1728 | + llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows; |
| 1729 | + |
| 1730 | + // Early quit if constraints is empty. |
| 1731 | + unsigned numConstraints = getNumConstraints(); |
| 1732 | + if (numConstraints == 0) |
| 1733 | + return; |
| 1734 | + |
| 1735 | + llvm::SmallVector<unsigned> rowStack, colStack({pos}); |
| 1736 | + // The following code performs a graph traversal, starting from the target |
| 1737 | + // variable, to identify all variables(recorded in relatedCols) and |
| 1738 | + // constraints(recorded in relatedRows) belonging to the same connected |
| 1739 | + // component. |
| 1740 | + while (!rowStack.empty() || !colStack.empty()) { |
| 1741 | + if (!rowStack.empty()) { |
| 1742 | + unsigned currentRow = rowStack.pop_back_val(); |
| 1743 | + // Push all variable that accociated to this constrain to relatedCols |
| 1744 | + // and colStack. |
| 1745 | + for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) { |
| 1746 | + if (currentRow < getNumInequalities()) { |
| 1747 | + if (atIneq(currentRow, colIndex) != 0 && |
| 1748 | + relatedCols.insert(colIndex).second) { |
| 1749 | + colStack.push_back(colIndex); |
| 1750 | + } |
| 1751 | + } else { |
| 1752 | + if (atEq(currentRow - getNumInequalities(), colIndex) != 0 && |
| 1753 | + relatedCols.insert(colIndex).second) { |
| 1754 | + colStack.push_back(colIndex); |
| 1755 | + } |
| 1756 | + } |
| 1757 | + } |
| 1758 | + } else { |
| 1759 | + unsigned currentCol = colStack.pop_back_val(); |
| 1760 | + // Push all constrains that accociated to this variable to relatedRows |
| 1761 | + // and rowStack. |
| 1762 | + for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { |
| 1763 | + if (rowIndex < getNumInequalities()) { |
| 1764 | + if (atIneq(rowIndex, currentCol) != 0 && |
| 1765 | + relatedRows.insert(rowIndex).second) { |
| 1766 | + rowStack.push_back(rowIndex); |
| 1767 | + } |
| 1768 | + } else { |
| 1769 | + if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 && |
| 1770 | + relatedRows.insert(rowIndex).second) { |
| 1771 | + rowStack.push_back(rowIndex); |
| 1772 | + } |
| 1773 | + } |
| 1774 | + } |
| 1775 | + } |
| 1776 | + } |
| 1777 | + |
| 1778 | + // Prune all constraints not related to target variable. |
| 1779 | + for (int64_t constraintId = numConstraints - 1; constraintId >= 0; |
| 1780 | + --constraintId) { |
| 1781 | + if (!relatedRows.contains(constraintId)) { |
| 1782 | + if (constraintId >= getNumInequalities()) { |
| 1783 | + removeEquality(constraintId - getNumInequalities()); |
| 1784 | + } else { |
| 1785 | + removeInequality(constraintId); |
| 1786 | + } |
| 1787 | + } |
| 1788 | + } |
| 1789 | +} |
| 1790 | + |
1726 | 1791 | template <bool isLower> |
1727 | 1792 | std::optional<DynamicAPInt> |
1728 | 1793 | IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { |
1729 | 1794 | assert(pos < getNumVars() && "invalid position"); |
1730 | 1795 | // Project to 'pos'. |
| 1796 | + pruneConstraints(pos); |
1731 | 1797 | projectOut(0, pos); |
| 1798 | + pruneConstraints(0); |
1732 | 1799 | projectOut(1, getNumVars() - 1); |
1733 | 1800 | // Check if there's an equality equating the '0'^th variable to a constant. |
1734 | 1801 | int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false); |
|
0 commit comments