Skip to content

Commit f736501

Browse files
Enhance unionBoundingBox utility
Enhance `unionBoundingBox` utility to work with input constraints having local variables.
1 parent 11766a4 commit f736501

File tree

5 files changed

+47
-48
lines changed

5 files changed

+47
-48
lines changed

mlir/include/mlir/Analysis/FlatLinearValueConstraints.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,11 +474,10 @@ class FlatLinearValueConstraints : public FlatLinearConstraints {
474474
bool areVarsAlignedWithOther(const FlatLinearConstraints &other);
475475

476476
/// Updates the constraints to be the smallest bounding (enclosing) box that
477-
/// contains the points of `this` set and that of `other`, with the symbols
478-
/// being treated specially. For each of the dimensions, the min of the lower
479-
/// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
480-
/// to determine such a bounding box. `other` is expected to have the same
481-
/// dimensional variables as this constraint system (in the same order).
477+
/// contains the points of `this` set and that of `other`. For each of the
478+
/// dimensions, the min of the lower bounds and the max of the upper bounds is
479+
/// computed to determine such a bounding box. `other` is expected to have the
480+
/// same dimensional variables as this constraint system (in the same order).
482481
///
483482
/// E.g.:
484483
/// 1) this = {0 <= d0 <= 127},

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -489,11 +489,10 @@ class IntegerRelation {
489489
void constantFoldVarRange(unsigned pos, unsigned num);
490490

491491
/// Updates the constraints to be the smallest bounding (enclosing) box that
492-
/// contains the points of `this` set and that of `other`, with the symbols
493-
/// being treated specially. For each of the dimensions, the min of the lower
494-
/// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
495-
/// to determine such a bounding box. `other` is expected to have the same
496-
/// dimensional variables as this constraint system (in the same order).
492+
/// contains the points of `this` set and that of `other`. For each of the
493+
/// dimensions, the min of the lower bounds and the max of the upper bounds is
494+
/// computed to determine such a bounding box. `other` is expected to have the
495+
/// same dimensional variables as this constraint system (in the same order).
497496
///
498497
/// E.g.:
499498
/// 1) this = {0 <= d0 <= 127},
@@ -512,14 +511,13 @@ class IntegerRelation {
512511
/// than or equal to 'exclusive upper bound' - 'lower bound' of the
513512
/// variable. This constant bound is guaranteed to be non-negative. Returns
514513
/// std::nullopt if it's not a constant. This method employs trivial (low
515-
/// complexity / cost) checks and detection. Symbolic variables are treated
516-
/// specially, i.e., it looks for constant differences between affine
517-
/// expressions involving only the symbolic variables. `lb` and `ub` (along
518-
/// with the `boundFloorDivisor`) are set to represent the lower and upper
519-
/// bound associated with the constant difference: `lb`, `ub` have the
520-
/// coefficients, and `boundFloorDivisor`, their divisor. `minLbPos` and
521-
/// `minUbPos` if non-null are set to the position of the constant lower bound
522-
/// and upper bound respectively (to the same if they are from an
514+
/// complexity / cost) checks and detection. It looks for constant differences
515+
/// between affine expressions involving symbolic and local variables. `lb`
516+
/// and `ub` (along with the `boundFloorDivisor`) are set to represent the
517+
/// lower and upper bound associated with the constant difference: `lb`, `ub`
518+
/// have the coefficients, and `boundFloorDivisor`, their divisor. `minLbPos`
519+
/// and `minUbPos` if non-null are set to the position of the constant lower
520+
/// bound and upper bound respectively (to the same if they are from an
523521
/// equality). Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a
524522
/// system with three symbolic variables, *lb = [1, 0, 1], lbDivisor = 32. See
525523
/// comments at function definition for examples.

mlir/lib/Analysis/FlatLinearValueConstraints.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,8 +1303,6 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
13031303
otherMaybeValues.begin(),
13041304
otherMaybeValues.begin() + getNumDimVars()) &&
13051305
"dim values mismatch");
1306-
assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1307-
assert(getNumLocalVars() == 0 && "local vars not supported yet here");
13081306

13091307
// Align `other` to this.
13101308
if (!areVarsAligned(*this, otherCst)) {

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,13 +1578,11 @@ void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
15781578

15791579
/// Returns a non-negative constant bound on the extent (upper bound - lower
15801580
/// bound) of the specified variable if it is found to be a constant; returns
1581-
/// std::nullopt if it's not a constant. This methods treats symbolic variables
1582-
/// specially, i.e., it looks for constant differences between affine
1583-
/// expressions involving only the symbolic variables. See comments at function
1584-
/// definition for example. 'lb', if provided, is set to the lower bound
1585-
/// associated with the constant difference. Note that 'lb' is purely symbolic
1586-
/// and thus will contain the coefficients of the symbolic variables and the
1587-
/// constant coefficient.
1581+
/// std::nullopt if it's not a constant. This methods looks for constant
1582+
/// differences between affine expressions. See comments at function definition
1583+
/// for example. 'lb', if provided, is set to the lower bound associated with
1584+
/// the constant difference. `lb' will contain the coefficients of the symbolic
1585+
/// variables, local variables and the constant coefficient.
15881586
// Egs: 0 <= i <= 15, return 16.
15891587
// s0 + 2 <= i <= s0 + 17, returns 16. (s0 has to be a symbol)
15901588
// s0 + s1 + 16 <= d0 <= s0 + s1 + 31, returns 16.
@@ -1600,22 +1598,15 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
16001598
// of the symbolic variables (+ constant).
16011599
int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
16021600
if (eqPos != -1) {
1603-
auto eq = getEquality(eqPos);
1604-
// If the equality involves a local var, punt for now.
1605-
// TODO: this can be handled in the future by using the explicit
1606-
// representation of the local vars.
1607-
if (!std::all_of(eq.begin() + getNumDimAndSymbolVars(), eq.end() - 1,
1608-
[](const DynamicAPInt &coeff) { return coeff == 0; }))
1609-
return std::nullopt;
1610-
16111601
// This variable can only take a single value.
16121602
if (lb) {
16131603
// Set lb to that symbolic value.
1614-
lb->resize(getNumSymbolVars() + 1);
1604+
lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
16151605
if (ub)
1616-
ub->resize(getNumSymbolVars() + 1);
1617-
for (unsigned c = 0, f = getNumSymbolVars() + 1; c < f; c++) {
1618-
DynamicAPInt v = atEq(eqPos, pos);
1606+
ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
1607+
for (unsigned c = 0, f = getNumSymbolVars() + getNumLocalVars() + 1;
1608+
c < f; c++) {
1609+
MPInt v = atEq(eqPos, pos);
16191610
// atEq(eqRow, pos) is either -1 or 1.
16201611
assert(v * v == 1);
16211612
(*lb)[c] = v < 0 ? atEq(eqPos, getNumDimVars() + c) / -v
@@ -1687,27 +1678,30 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
16871678
}
16881679
if (lb && minDiff) {
16891680
// Set lb to the symbolic lower bound.
1690-
lb->resize(getNumSymbolVars() + 1);
1681+
lb->resize(getNumSymbolVars() + getNumLocalVars() + 1);
16911682
if (ub)
1692-
ub->resize(getNumSymbolVars() + 1);
1683+
ub->resize(getNumSymbolVars() + getNumLocalVars() + 1);
16931684
// The lower bound is the ceildiv of the lb constraint over the coefficient
16941685
// of the variable at 'pos'. We express the ceildiv equivalently as a floor
16951686
// for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
16961687
// 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
16971688
*boundFloorDivisor = atIneq(minLbPosition, pos);
16981689
assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
1699-
for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++) {
1690+
for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1; c < e;
1691+
c++) {
17001692
(*lb)[c] = -atIneq(minLbPosition, getNumDimVars() + c);
17011693
}
17021694
if (ub) {
1703-
for (unsigned c = 0, e = getNumSymbolVars() + 1; c < e; c++)
1695+
for (unsigned c = 0, e = getNumSymbolVars() + getNumLocalVars() + 1;
1696+
c < e; c++)
17041697
(*ub)[c] = atIneq(minUbPosition, getNumDimVars() + c);
17051698
}
17061699
// The lower bound leads to a ceildiv while the upper bound is a floordiv
17071700
// whenever the coefficient at pos != 1. ceildiv (val / d) = floordiv (val +
17081701
// d - 1 / d); hence, the addition of 'atIneq(minLbPosition, pos) - 1' to
17091702
// the constant term for the lower bound.
1710-
(*lb)[getNumSymbolVars()] += atIneq(minLbPosition, pos) - 1;
1703+
(*lb)[getNumSymbolVars() + getNumLocalVars()] +=
1704+
atIneq(minLbPosition, pos) - 1;
17111705
}
17121706
if (minLbPos)
17131707
*minLbPos = minLbPosition;
@@ -2180,8 +2174,6 @@ static void getCommonConstraints(const IntegerRelation &a,
21802174
LogicalResult
21812175
IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
21822176
assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
2183-
assert(getNumLocalVars() == 0 && "local ids not supported yet here");
2184-
21852177
// Get the constraints common to both systems; these will be added as is to
21862178
// the union.
21872179
IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
@@ -2211,11 +2203,9 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
22112203
auto otherExtent = otherCst.getConstantBoundOnDimSize(
22122204
d, &otherLb, &otherLbFloorDivisor, &otherUb);
22132205
if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
2214-
// TODO: symbolic extents when necessary.
22152206
return failure();
22162207

22172208
assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
2218-
22192209
auto res = compareBounds(lb, otherLb);
22202210
// Identify min.
22212211
if (res == BoundCmpResult::Less || res == BoundCmpResult::Equal) {

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,3 +608,17 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
608608
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
609609
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
610610
}
611+
612+
// Test union of two integer relations if they have local variable(s).
613+
TEST(IntegerRelationTest, unionBoundingBox) {
614+
IntegerRelation relA = parseRelationFromSet(
615+
"(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - N - x"
616+
">= 0, x + y + z floordiv 6 == 0)",
617+
1);
618+
IntegerRelation relB = parseRelationFromSet(
619+
"(x, y, z)[N, M]: (y floordiv 2 - N + x == 0, z floordiv 5 - M - x"
620+
">= 0, x + y + z floordiv 7 == 0)",
621+
1);
622+
assert(relA.getNumLocalVars() > 0);
623+
EXPECT_TRUE(relA.unionBoundingBox(relB).succeeded());
624+
}

0 commit comments

Comments
 (0)