diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h index 0e6d18279d67e..6d06daa91d376 100644 --- a/mlir/include/mlir/Analysis/Presburger/Utils.h +++ b/mlir/include/mlir/Analysis/Presburger/Utils.h @@ -96,7 +96,8 @@ enum class ReprKind { Inequality, Equality, None }; /// set to None. struct MaybeLocalRepr { ReprKind kind = ReprKind::None; - explicit operator bool() const { return kind != ReprKind::None; } + explicit operator bool() const { return hasRepr(); } + bool hasRepr() const { return kind != ReprKind::None; } union { unsigned equalityIdx; struct { diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 74cdf567c0e56..d15efaa1b86f9 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -225,35 +225,6 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const { if (getNumLocalVars() == 0) return PresburgerRelation(*this); - // Move all the non-div locals to the end, as the current API to - // SymbolicLexOpt requires these to form a contiguous range. - // - // Take a copy so we can perform mutations. - IntegerRelation copy = *this; - std::vector reprs(getNumLocalVars()); - copy.getLocalReprs(&reprs); - - // Iterate through all the locals. The last `numNonDivLocals` are the locals - // that have been scanned already and do not have division representations. - unsigned numNonDivLocals = 0; - unsigned offset = copy.getVarKindOffset(VarKind::Local); - for (unsigned i = 0, e = copy.getNumLocalVars(); i < e - numNonDivLocals;) { - if (!reprs[i]) { - // Whenever we come across a local that does not have a division - // representation, we swap it to the `numNonDivLocals`-th last position - // and increment `numNonDivLocal`s. `reprs` also needs to be swapped. - copy.swapVar(offset + i, offset + e - numNonDivLocals - 1); - std::swap(reprs[i], reprs[e - numNonDivLocals - 1]); - ++numNonDivLocals; - continue; - } - ++i; - } - - // If there are no non-div locals, we're done. - if (numNonDivLocals == 0) - return PresburgerRelation(*this); - // We computeSymbolicIntegerLexMin by considering the non-div locals as // "non-symbols" and considering everything else as "symbols". This will // compute a function mapping assignments to "symbols" to the @@ -265,10 +236,29 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const { // exists, which is the union of the domain of the returned lexmin function // and the returned set of assignments to the "symbols" that makes the lexmin // unbounded. + + // Construct a BitVector that identifies which variables we would like to + // treat as symbols. We want to treat all variables as "symbols" except for + // the locals that don't have a division representation. + std::vector reprs(getNumLocalVars()); + this->getLocalReprs(&reprs); + + llvm::SmallBitVector isSymbol(getNumVars(), true); + unsigned offset = getVarKindOffset(VarKind::Local); + for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) { + isSymbol[offset + i] = reprs[i].hasRepr(); + } + + // If there are no non-div locals (non-symbols), we're done. + if (isSymbol.all()) + return PresburgerRelation(*this); + SymbolicLexOpt lexminResult = - SymbolicLexSimplex(copy, /*symbolOffset*/ 0, + SymbolicLexSimplex(/*constraints=*/*this, + /*symbolDomain=*/ IntegerPolyhedron(PresburgerSpace::getSetSpace( - /*numDims=*/copy.getNumVars() - numNonDivLocals))) + /*numDims=*/isSymbol.count())), + isSymbol) .computeSymbolicIntegerLexMin(); PresburgerRelation result = lexminResult.lexopt.getDomain().unionSet(lexminResult.unboundedDomain); diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 4ffa2d546af4d..3b8b06dce8f25 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -64,13 +64,36 @@ SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, const llvm::SmallBitVector &isSymbol) : SimplexBase(nVar, mustUseBigM) { assert(isSymbol.size() == nVar && "invalid bitmask!"); - // Invariant: nSymbol is the number of symbols that have been marked - // already and these occupy the columns - // [getNumFixedCols(), getNumFixedCols() + nSymbol). - for (unsigned symbolIdx : isSymbol.set_bits()) { - var[symbolIdx].isSymbol = true; - swapColumns(var[symbolIdx].pos, getNumFixedCols() + nSymbol); - ++nSymbol; + // Iterate through all the variables. Move symbols to the left and non-symbols + // to the right while preserving relative ordering. + for (unsigned i = 0; i < nVar; ++i) { + if (isSymbol[i]) { + + // Move the column from its current position to the end of + // the symbols segment and update the position metadata for each column + // unknown (which should all be vars at construction time). The + // segment of non-symbol up until the current variable's column are + // shifted to the right and the current column is then moved before the + // right-shifted segment. + tableau.moveColumns(var[i].pos, 1, getNumFixedCols() + nSymbol); + + // Update the column position metadata for the unknowns associated with + // the right-shifted columns. + for (unsigned col = getNumFixedCols() + nSymbol; col < var[i].pos; ++col) + unknownFromColumn(col).pos = col + 1; + + // Perform the equivalent rearrangement on the col-to-unknown + // mapping. + std::rotate(colUnknown.begin() + getNumFixedCols() + nSymbol, + colUnknown.begin() + var[i].pos, + colUnknown.begin() + var[i].pos + 1); + + // Update the column mapping for the current variable. + var[i].pos = getNumFixedCols() + nSymbol; + var[i].isSymbol = true; + + ++nSymbol; + } } } diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp index 8e31a8bb2030b..7561d2044bd77 100644 --- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp +++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp @@ -855,6 +855,20 @@ TEST(SetTest, computeReprWithOnlyDivLocals) { PresburgerSet(parseIntegerPolyhedron( {"(x) : (x - 3*(x floordiv 3) == 0)"})), /*numToProject=*/2); + + testComputeRepr( + parseIntegerPolyhedron("(e, a, b, c)[] : (" + "a >= 0," + "b >= 0," + "c >= 0," + "e >= 0," + "15 - a >= 0," + "7 - b >= 0," + "5 - c >= 0," + "e - a * 192 - c * 32 - b * 4 >= 0," + "3 - e + a * 192 + c * 32 + b * 4 >= 0)"), + parsePresburgerSet({"(i) : (i >= 0, i <= 3071)"}), + /*numToProject=*/3); } TEST(SetTest, subtractOutputSizeRegression) {