-
Notifications
You must be signed in to change notification settings - Fork 15.3k
Fix bug in visitDivExpr, visitModExpr and visitMulExpr
#145290
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: Arnab Dutta (arnab-polymage) ChangesWhenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression. Full diff: https://github.com/llvm/llvm-project/pull/145290.diff 1 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!binaryExpr)
- continue;
-
+ assert(isa<AffineBinaryOpExpr>(expr) &&
+ "local expression cannot be a dimension, symbol or a constant -- it "
+ "should be a binary op expression");
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constModExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
}
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constDivExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
}
|
|
@llvm/pr-subscribers-mlir-core Author: Arnab Dutta (arnab-polymage) ChangesWhenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression. Full diff: https://github.com/llvm/llvm-project/pull/145290.diff 1 Files Affected:
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index cc81f9d19aca7..feedef46c66b8 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1177,10 +1177,10 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
if (flatExprs[numDims + numSymbols + it.index()] == 0)
continue;
AffineExpr expr = it.value();
- auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
- if (!binaryExpr)
- continue;
-
+ assert(isa<AffineBinaryOpExpr>(expr) &&
+ "local expression cannot be a dimension, symbol or a constant -- it "
+ "should be a binary op expression");
+ auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
AffineExpr lhs = binaryExpr.getLHS();
AffineExpr rhs = binaryExpr.getRHS();
if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
@@ -1348,6 +1348,11 @@ LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr modExpr = dividendExpr % divisorExpr;
+ if (auto constModExpr = dyn_cast<AffineConstantExpr>(modExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constModExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(modLhs, rhs, modExpr, lhs, lhs.size());
}
@@ -1482,6 +1487,11 @@ LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
localExprs, context);
AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
+ if (auto constDivExpr = dyn_cast<AffineConstantExpr>(divExpr)) {
+ std::fill(lhs.begin(), lhs.end(), 0);
+ lhs[getConstantIndex()] = constDivExpr.getValue();
+ return success();
+ }
return addLocalVariableSemiAffine(divLhs, rhs, divExpr, lhs, lhs.size());
}
|
f51da54 to
7c184d5
Compare
bondhugula
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes aren't complete or sound.
7c184d5 to
75ab97d
Compare
bondhugula
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing test cases for the scenarios exercised. Changes are completely untested!
Which constraint? |
75ab97d to
17ce1f0
Compare
visitDivExpr and visitModExprvisitDivExpr, visitModExpr and visitMulExpr
17ce1f0 to
5fc2915
Compare
5fc2915 to
7d262f5
Compare
7d262f5 to
c2493e6
Compare
c2493e6 to
c140442
Compare
c140442 to
f7715ad
Compare
mlir/lib/IR/AffineExpr.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is completely misnamed. Is this specific to semi-affine exprs? In fact, addLocalIdSemiAffine in any of its overrides doesn't even use lhs and rhs. So all of these arguments aren't making sense to me.
LogicalResult SimpleAffineExprFlattener::addLocalIdSemiAffine(
ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs, AffineExpr localExpr) {
for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
localExprs.push_back(localExpr);
++numLocals;
// lhs and rhs are not used here; an override of this method uses them.
return success();
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment inside the function you have pasted explains this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I don't see the override using them either.
Whenever the result of a div or mod affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.
f7715ad to
7813964
Compare
mlir/lib/IR/AffineExpr.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I don't see the override using them either.
| result[getSymbolStartIndex() + symExpr.getPosition()] = 1; | ||
| return success(); | ||
| } | ||
| return addLocalVariableSemiAffine(lhs, rhs, expr, result, result.size()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we know that expr is semi-affine at this stage? It can just be a purely affine binary expression. This method is confusing without any further comments. At this point, all you know is that expr is an affine binary expression and you might as well cast it to that and send it to make the signature of addLocal... less confusing.
Whenever the result of a div, mod or mul affine expression is a constant expression, place the value in the constant index of the flattened expression instead of adding it as a local expression.