Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,19 @@ class IntegerRelation {
/// this for uniformity with `applyDomain`.
void applyRange(const IntegerRelation &rel);

/// Let the relation `this` be R1, and the relation `rel` be R2. Requires
/// R1 and R2 to have the same domain.
///
/// Let R3 be the rangeProduct of R1 and R2. Then x R3 (y, z) iff
/// (x R1 y and x R2 z).
///
/// Example:
///
/// R1: (i, j) -> k : f(i, j, k) = 0
/// R2: (i, j) -> l : g(i, j, l) = 0
/// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
IntegerRelation rangeProduct(const IntegerRelation &rel);

/// Given a relation `other: (A -> B)`, this operation merges the symbol and
/// local variables and then takes the composition of `other` on `this: (B ->
/// C)`. The resulting relation represents tuples of the form: `A -> C`.
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Analysis/Presburger/IntegerRelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2481,6 +2481,44 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {

void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }

IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
/// R1: (i, j) -> k : f(i, j, k) = 0
/// R2: (i, j) -> l : g(i, j, l) = 0
/// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
assert(getNumDomainVars() == rel.getNumDomainVars() &&
"Range product is only defined for relations with equal domains");

// explicit copy of `this`
IntegerRelation result = *this;
unsigned relRangeVarStart = rel.getVarKindOffset(VarKind::Range);
unsigned numThisRangeVars = getNumRangeVars();
unsigned numNewSymbolVars = result.getNumSymbolVars() - getNumSymbolVars();

result.appendVar(VarKind::Range, rel.getNumRangeVars());

// Copy each equality from `rel` and update the copy to account for range
// variables from `this`. The `rel` equality is a list of coefficients of the
// variables from `rel`, and so the range variables need to be shifted right
// by the number of `this` range variables and symbols.
for (unsigned i = 0; i < rel.getNumEqualities(); ++i) {
SmallVector<DynamicAPInt> copy =
SmallVector<DynamicAPInt>(rel.getEquality(i));
copy.insert(copy.begin() + relRangeVarStart,
numThisRangeVars + numNewSymbolVars, DynamicAPInt(0));
result.addEquality(copy);
}

for (unsigned i = 0; i < rel.getNumInequalities(); ++i) {
SmallVector<DynamicAPInt> copy =
SmallVector<DynamicAPInt>(rel.getInequality(i));
copy.insert(copy.begin() + relRangeVarStart,
numThisRangeVars + numNewSymbolVars, DynamicAPInt(0));
result.addInequality(copy);
}

return result;
}

void IntegerRelation::printSpace(raw_ostream &os) const {
space.print(os);
os << getNumConstraints() << " constraints\n";
Expand Down
94 changes: 94 additions & 0 deletions mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,3 +608,97 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
}

TEST(IntegerRelationTest, rangeProduct) {
IntegerRelation r1 = parseRelationFromSet(
"(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
IntegerRelation r2 = parseRelationFromSet(
"(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);

IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected =
parseRelationFromSet("(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == "
"0, i >= 0, j >= 0, k >= 0, l >= 0)",
2);

EXPECT_TRUE(expected.isEqual(rangeProd));
}

TEST(IntegerRelationTest, rangeProductMultdimRange) {
IntegerRelation r1 =
parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1);
IntegerRelation r2 = parseRelationFromSet(
"(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1);

IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected =
parseRelationFromSet("(i, k, l, m) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == "
"0, i >= 0, k >= 0, l >= 0, m >= 0)",
1);

EXPECT_TRUE(expected.isEqual(rangeProd));
}

TEST(IntegerRelationTest, rangeProductMultdimRangeSwapped) {
IntegerRelation r1 = parseRelationFromSet(
"(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1);
IntegerRelation r2 =
parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1);

IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected =
parseRelationFromSet("(i, l, m, k) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == "
"0, i >= 0, k >= 0, l >= 0, m >= 0)",
1);

EXPECT_TRUE(expected.isEqual(rangeProd));
}

TEST(IntegerRelationTest, rangeProductEmptyDomain) {
IntegerRelation r1 =
parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 0);
IntegerRelation r2 =
parseRelationFromSet("(k, l) : (2*k + 3*l == 0, k >= 0, l >= 0)", 0);
IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected =
parseRelationFromSet("(i, j, k, l) : (2*k + 3*l == 0, 4*i + 9*j == "
"0, i >= 0, j >= 0, k >= 0, l >= 0)",
0);
EXPECT_TRUE(expected.isEqual(rangeProd));
}

TEST(IntegerRelationTest, rangeProductEmptyRange) {
IntegerRelation r1 =
parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 2);
IntegerRelation r2 =
parseRelationFromSet("(i, j) : (2*i + 3*j == 0, i >= 0, j >= 0)", 2);
IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected =
parseRelationFromSet("(i, j) : (2*i + 3*j == 0, 4*i + 9*j == "
"0, i >= 0, j >= 0)",
2);
EXPECT_TRUE(expected.isEqual(rangeProd));
}

TEST(IntegerRelationTest, rangeProductEmptyDomainAndRange) {
IntegerRelation r1 = parseRelationFromSet("() : ()", 0);
IntegerRelation r2 = parseRelationFromSet("() : ()", 0);
IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected = parseRelationFromSet("() : ()", 0);
EXPECT_TRUE(expected.isEqual(rangeProd));
}

TEST(IntegerRelationTest, rangeProductSymbols) {
IntegerRelation r1 = parseRelationFromSet(
"(i, j)[s] : (2*i + 3*j + s == 0, i >= 0, j >= 0)", 1);
IntegerRelation r2 = parseRelationFromSet(
"(i, l)[s] : (3*i + 4*l + s == 0, i >= 0, l >= 0)", 1);

IntegerRelation rangeProd = r1.rangeProduct(r2);
IntegerRelation expected = parseRelationFromSet(
"(i, j, l)[s] : (2*i + 3*j + s == 0, 3*i + 4*l + s == "
"0, i >= 0, j >= 0, l >= 0)",
1);

EXPECT_TRUE(expected.isEqual(rangeProd));
}