-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][Presburger] optimize bound computation by pruning orthogonal constraints #164199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -511,6 +511,31 @@ class IntegerRelation { | |
| void projectOut(unsigned pos, unsigned num); | ||
| inline void projectOut(unsigned pos) { return projectOut(pos, 1); } | ||
|
|
||
| /// The set of constraints (equations/inequalities) can be modeled as an | ||
| /// undirected graph where: | ||
| /// 1. Variables are the nodes. | ||
| /// 2. Constraints are the edges connecting those nodes. | ||
| /// | ||
| /// Variables and constraints belonging to different connected components | ||
| /// are irrelevant to each other. This property allows for safe pruning of | ||
| /// constraints. | ||
| /// | ||
| /// For example, given the following constraints: | ||
| /// - Inequalities: (1) d0 + d1 > 0, (2) d1 >= 2, (3) d4 > 5 | ||
| /// - Equalities: (4) d3 + d4 = 1, (5) d0 - d2 = 3 | ||
| /// | ||
| /// These form two connected components: | ||
| /// - Component 1: {d0, d1, d2} (related by constraints 1, 2, 5) | ||
| /// - Component 2: {d3, d4} (related by constraint 4) | ||
| /// | ||
| /// If we are querying the bound of variable `d0`, constraints related to | ||
| /// Component 2 (e.g., constraints 3 and 4) can be safely pruned as they | ||
| /// have no impact on the solution space of Component 1. | ||
| /// This function prunes irrelevant constraints by identifying all variables | ||
| /// and constraints that belong to the same connected component as the | ||
| /// target variable. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we change the name to |
||
| void pruneConstraints(unsigned pos); | ||
|
|
||
| /// Tries to fold the specified variable to a constant using a trivial | ||
| /// equality detection; if successful, the constant is substituted for the | ||
| /// variable everywhere in the constraint system and then removed from the | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ | |
| #include "mlir/Analysis/Presburger/Simplex.h" | ||
| #include "mlir/Analysis/Presburger/Utils.h" | ||
| #include "llvm/ADT/DenseMap.h" | ||
| #include "llvm/ADT/DenseSet.h" | ||
| #include "llvm/ADT/STLExtras.h" | ||
| #include "llvm/ADT/Sequence.h" | ||
| #include "llvm/ADT/SmallBitVector.h" | ||
|
|
@@ -1723,12 +1724,78 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize( | |
| return minDiff; | ||
| } | ||
|
|
||
| void IntegerRelation::pruneConstraints(unsigned pos) { | ||
| llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows; | ||
|
|
||
| // Early quit if constraints is empty. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: we usually call this early exit! |
||
| unsigned numConstraints = getNumConstraints(); | ||
| if (numConstraints == 0) | ||
| return; | ||
|
|
||
| llvm::SmallVector<unsigned> rowStack, colStack({pos}); | ||
| // The following code performs a graph traversal, starting from the target | ||
| // variable, to identify all variables(recorded in relatedCols) and | ||
| // constraints(recorded in relatedRows) belonging to the same connected | ||
|
Comment on lines
+1737
to
+1738
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: please leave a space before opening parentheses |
||
| // component. | ||
| while (!rowStack.empty() || !colStack.empty()) { | ||
| if (!rowStack.empty()) { | ||
| unsigned currentRow = rowStack.pop_back_val(); | ||
| // Push all variable that accociated to this constrain to relatedCols | ||
| // and colStack. | ||
| for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) { | ||
| if (currentRow < getNumInequalities()) { | ||
| if (atIneq(currentRow, colIndex) != 0 && | ||
| relatedCols.insert(colIndex).second) { | ||
| colStack.push_back(colIndex); | ||
| } | ||
| } else { | ||
| if (atEq(currentRow - getNumInequalities(), colIndex) != 0 && | ||
| relatedCols.insert(colIndex).second) { | ||
| colStack.push_back(colIndex); | ||
| } | ||
| } | ||
| } | ||
| } else { | ||
| unsigned currentCol = colStack.pop_back_val(); | ||
| // Push all constrains that accociated to this variable to relatedRows | ||
| // and rowStack. | ||
|
Comment on lines
+1760
to
+1761
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please fix the spelling/grammar on this line (and above). |
||
| for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can just use an unsigned for the rowIndex. |
||
| if (rowIndex < getNumInequalities()) { | ||
| if (atIneq(rowIndex, currentCol) != 0 && | ||
| relatedRows.insert(rowIndex).second) { | ||
| rowStack.push_back(rowIndex); | ||
| } | ||
| } else { | ||
| if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 && | ||
| relatedRows.insert(rowIndex).second) { | ||
| rowStack.push_back(rowIndex); | ||
| } | ||
| } | ||
|
Comment on lines
+1763
to
+1773
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There seems to be a lot of duplication here and above. We should pull out a function (maybe just a lambda) called atConstraint or something to avoid the duplicated if/else. |
||
| } | ||
| } | ||
| } | ||
|
|
||
| // Prune all constraints not related to target variable. | ||
| for (int64_t constraintId = numConstraints - 1; constraintId >= 0; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can just use an int for the constraintId. |
||
| --constraintId) { | ||
| if (!relatedRows.contains(constraintId)) { | ||
| if (constraintId >= getNumInequalities()) { | ||
|
Comment on lines
+1781
to
+1782
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's preferred to early-exit if constraintId is relevant and then do the removal otherwise, to reduce nesting (https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code) |
||
| removeEquality(constraintId - getNumInequalities()); | ||
| } else { | ||
| removeInequality(constraintId); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <bool isLower> | ||
| std::optional<DynamicAPInt> | ||
| IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) { | ||
| assert(pos < getNumVars() && "invalid position"); | ||
| // Project to 'pos'. | ||
| pruneConstraints(pos); | ||
| projectOut(0, pos); | ||
| pruneConstraints(0); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you write a comment here explaining why the second pruneConstraints might be helpful? |
||
| projectOut(1, getNumVars() - 1); | ||
| // Check if there's an equality equating the '0'^th variable to a constant. | ||
| int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you write a 1-2 line summary at the top, so that people can quickly understand what the function does, while keeping the details below? You can write that the function removes some constraints that do not impose any bound on the specified variable.