Skip to content

Commit b045729

Browse files
authored
[mlir][presburger] add functionality to compute local mod in IntegerRelation (#153614)
Similar to `IntegerRelation::addLocalFloorDiv`, this adds a utility `IntegerRelation::addLocalModulo` that adds and returns a local variable that is the modulus of an affine function of the variables modulo some constant modulus. The function returns the absolute index of the new var in the relation. This is computed by first finding the floordiv of `exprs // modulus = q` and then computing the remainder `result = exprs - q * modulus`. Signed-off-by: Asra Ali <[email protected]>
1 parent a8d2568 commit b045729

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

mlir/include/mlir/Analysis/Presburger/IntegerRelation.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,22 @@ class IntegerRelation {
485485
addLocalFloorDiv(getDynamicAPIntVec(dividend), DynamicAPInt(divisor));
486486
}
487487

488+
/// Adds a new local variable as the modulus of an affine function of other
489+
/// variables, the coefficients of which are provided in `exprs`. The modulus
490+
/// is with respect to a positive constant `modulus`. The function returns the
491+
/// absolute index of the new local variable representing the result of the
492+
/// modulus operation. Two new local variables are added to the system, one
493+
/// representing the floor div with respect to the modulus and one
494+
/// representing the mod. Three constraints are added to the system to capture
495+
/// the equivalance. The first two are required to compute the result of the
496+
/// floor division `q`, and the third computes the equality relation:
497+
/// result = exprs - modulus * q.
498+
unsigned addLocalModulo(ArrayRef<DynamicAPInt> exprs,
499+
const DynamicAPInt &modulus);
500+
unsigned addLocalModulo(ArrayRef<int64_t> exprs, int64_t modulus) {
501+
return addLocalModulo(getDynamicAPIntVec(exprs), DynamicAPInt(modulus));
502+
}
503+
488504
/// Projects out (aka eliminates) `num` variables starting at position
489505
/// `pos`. The resulting constraint system is the shadow along the dimensions
490506
/// that still exist. This method may not always be integer exact.

mlir/lib/Analysis/Presburger/IntegerRelation.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,27 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
15151515
getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
15161516
}
15171517

1518+
unsigned IntegerRelation::addLocalModulo(ArrayRef<DynamicAPInt> exprs,
1519+
const DynamicAPInt &modulus) {
1520+
assert(exprs.size() == getNumCols() && "incorrect exprs size");
1521+
assert(modulus > 0 && "positive modulus expected");
1522+
1523+
/// Add a local variable for q = expr floordiv modulus
1524+
addLocalFloorDiv(exprs, modulus);
1525+
1526+
/// Add a local var to represent the result
1527+
auto resultIndex = appendVar(VarKind::Local);
1528+
1529+
SmallVector<DynamicAPInt, 8> exprsCopy(exprs);
1530+
/// Insert the two new locals before the constant
1531+
/// Add locals that correspond to `q` and `result` to compute
1532+
/// 0 = (expr - modulus * q) - result
1533+
exprsCopy.insert(exprsCopy.end() - 1,
1534+
{DynamicAPInt(-modulus), DynamicAPInt(-1)});
1535+
addEquality(exprsCopy);
1536+
return resultIndex;
1537+
}
1538+
15181539
int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const {
15191540
assert(pos < getNumVars() && "invalid position");
15201541
for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {

mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,3 +714,14 @@ TEST(IntegerRelationTest, getVarKindRange) {
714714
}
715715
EXPECT_THAT(actual, ElementsAre(2, 3, 4));
716716
}
717+
718+
TEST(IntegerRelationTest, addLocalModulo) {
719+
IntegerRelation rel = parseRelationFromSet("(x) : (x >= 0, 100 - x >= 0)", 1);
720+
unsigned result = rel.addLocalModulo({1, 0}, 32); // x % 32
721+
rel.convertVarKind(VarKind::Local,
722+
result - rel.getVarKindOffset(VarKind::Local),
723+
rel.getNumVarKind(VarKind::Local), VarKind::Range);
724+
for (unsigned x = 0; x <= 100; ++x) {
725+
EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32}));
726+
}
727+
}

0 commit comments

Comments
 (0)