-
Notifications
You must be signed in to change notification settings - Fork 14.9k
Lower affine modulo by powers of two using bitwise AND #146311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine Author: Yuxi Sun (sherylll) ChangesThis patch adds a special-case optimization in the affine-to-standard lowering pass to replace modulo operations by constant powers of two with a single bitwise AND operation. This reduces instruction count and improves performance for common cases like Full diff: https://github.com/llvm/llvm-project/pull/146311.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 66b3f2a4f93a5..de9c7874767e4 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -80,12 +80,24 @@ class AffineApplyExpander
/// let remainder = srem a, b;
/// negative = a < 0 in
/// select negative, remainder + b, remainder.
+ ///
+ /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative x.
Value visitModExpr(AffineBinaryOpExpr expr) {
if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "modulo by non-positive value is not supported");
return nullptr;
}
+
+ // Special case: x mod n where n is a power of 2 can be optimized to x & (n-1)
+ int64_t rhsValue = rhsConst.getValue();
+ if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
+ auto lhs = visit(expr.getLHS());
+ assert(lhs && "unexpected affine expr lowering failure");
+
+ Value maskCst = builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
+ return builder.create<arith::AndIOp>(loc, lhs, maskCst);
+ }
}
auto lhs = visit(expr.getLHS());
diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
index 550ea71882e14..07f7c64fe6ea5 100644
--- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
+++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir
@@ -927,3 +927,12 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me
// CHECK: scf.reduce.return %[[RES]] : i64
// CHECK: }
// CHECK: }
+
+#map_mod_8 = affine_map<(i) -> (i mod 8)>
+// CHECK-LABEL: func @affine_apply_mod_8
+func.func @affine_apply_mod_8(%arg0 : index) -> (index) {
+ // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
+ // CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index
+ %0 = affine.apply #map_mod_8 (%arg0)
+ return %0 : index
+}
|
You can test this locally with the following command:git-clang-format --diff HEAD~1 HEAD --extensions cpp -- mlir/lib/Dialect/Affine/Utils/Utils.cpp View the diff from clang-format here.diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index de9c78747..0cffe52dd 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -81,21 +81,24 @@ public:
/// negative = a < 0 in
/// select negative, remainder + b, remainder.
///
- /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative x.
+ /// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative
+ /// x.
Value visitModExpr(AffineBinaryOpExpr expr) {
if (auto rhsConst = dyn_cast<AffineConstantExpr>(expr.getRHS())) {
if (rhsConst.getValue() <= 0) {
emitError(loc, "modulo by non-positive value is not supported");
return nullptr;
}
-
- // Special case: x mod n where n is a power of 2 can be optimized to x & (n-1)
+
+ // Special case: x mod n where n is a power of 2 can be optimized to x &
+ // (n-1)
int64_t rhsValue = rhsConst.getValue();
if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) {
auto lhs = visit(expr.getLHS());
assert(lhs && "unexpected affine expr lowering failure");
-
- Value maskCst = builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
+
+ Value maskCst =
+ builder.create<arith::ConstantIndexOp>(loc, rhsValue - 1);
return builder.create<arith::AndIOp>(loc, lhs, maskCst);
}
}
|
9792917
to
db4de47
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why would we not do this on arith.remsi / arith.remui?
@Groverkss you mean add a canonicalizer to arith.remsi / arith.remui? In that case do we also need to handle cmp + select? The case I have in mind is incrementing |
According to https://mlir.llvm.org/docs/Dialects/Affine/#affinefor-affineaffineforop: mod is the modulo operation: since its second argument is always positive, its results are always positive in our usage.
I guess we could implement different canonicalizers for |
You are right - this is accurate. The result of an affine mod expression is guaranteed to be positive. So, exploiting this information right away will lead to a single |
If you had I'm actually surprised we missed this optimization for 6 years! :-) I'm in favor of adding this - perhaps under a flag if needed. CC: @ftynse for review as well. |
Change request comment already responded to, and it's time to take another look. Rerequesting review.
func.func @affine_apply_mod_8(%arg0 : index) -> (index) { | ||
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index | ||
// CHECK-NEXT: %[[v0:.*]] = arith.andi %arg0, %[[c7]] : index | ||
%0 = affine.apply #map_mod_8 (%arg0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can put the map inline for better readability.
// CHECK: } | ||
|
||
#map_mod_8 = affine_map<(i) -> (i mod 8)> | ||
// CHECK-LABEL: func @affine_apply_mod_8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also check if using ... mod 1
doesn't lead to any unexpected behavior.
} | ||
|
||
// Special case: x mod n where n is a power of 2 can be optimized to x & | ||
// (n-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Terminate all comments with a full stop. LLVM style.
/// negative = a < 0 in | ||
/// select negative, remainder + b, remainder. | ||
/// | ||
/// Special case for power of 2: use bitwise AND (x & (n-1)) for non-negative |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
power of 2 RHS
// Special case: x mod n where n is a power of 2 can be optimized to x & | ||
// (n-1) | ||
int64_t rhsValue = rhsConst.getValue(); | ||
if (rhsValue > 0 && (rhsValue & (rhsValue - 1)) == 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But we've already returned for all rhs values <= 0 at L90. rhsValue
is guaranteed to be positive now. This check is unnecessary.
|
||
// Special case: x mod n where n is a power of 2 can be optimized to x & | ||
// (n-1) | ||
int64_t rhsValue = rhsConst.getValue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this assignment to above L88 to avoid multiple calls to getValue().
Sorry, i didn't look at this until i got a "rerequest review" notification. Thanks for answering why it cannot be done on remsi/remui, that sounds fair, and I wouldve dismissed my review.
So this optimization works for I usually prefer lowering to an op, and then letting the op canonicalize itself to something better, based on the lowered op's properties, than special casing the lowering, but that's a preference and I don't know what the correct way to do it is, so i'll not block this. |
Absolutely. No range info is needed - this PR already does it.
Typically, an op shouldn't be added to such a dialect (like |
@sherylll - are you available to take this forward? |
This patch adds a special-case optimization in the affine-to-standard lowering pass to replace modulo operations by constant powers of two with a single bitwise AND operation. This reduces instruction count and improves performance for common cases like
x mod 2
.