-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion #112454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -13,7 +13,8 @@ | |||||
| #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" | ||||||
| #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" | ||||||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||||||
| #include "mlir/Transforms/DialectConversion.h" | ||||||
| #include "mlir/IR/PatternMatch.h" | ||||||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||||||
|
|
||||||
| namespace mlir { | ||||||
| namespace arith { | ||||||
|
|
@@ -29,6 +30,9 @@ using namespace mlir::dataflow; | |||||
| /// Succeeds when a value is statically non-negative in that it has a lower | ||||||
| /// bound on its value (if it is treated as signed) and that bound is | ||||||
| /// non-negative. | ||||||
| // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern | ||||||
| // relies on this. These transformations may not be valid for 32bit index, | ||||||
| // need more investigation. | ||||||
| static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { | ||||||
| auto *result = solver.lookupState<IntegerValueRangeLattice>(v); | ||||||
| if (!result || result->getValue().isUninitialized()) | ||||||
|
|
@@ -85,35 +89,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { | |||||
| } | ||||||
|
|
||||||
| namespace { | ||||||
| class DataFlowListener : public RewriterBase::Listener { | ||||||
| public: | ||||||
| DataFlowListener(DataFlowSolver &s) : s(s) {} | ||||||
|
|
||||||
| protected: | ||||||
| void notifyOperationErased(Operation *op) override { | ||||||
| s.eraseState(s.getProgramPointAfter(op)); | ||||||
| for (Value res : op->getResults()) | ||||||
| s.eraseState(res); | ||||||
| } | ||||||
|
|
||||||
| DataFlowSolver &s; | ||||||
| }; | ||||||
|
|
||||||
| template <typename Signed, typename Unsigned> | ||||||
| struct ConvertOpToUnsigned : OpConversionPattern<Signed> { | ||||||
| using OpConversionPattern<Signed>::OpConversionPattern; | ||||||
| struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> { | ||||||
|
||||||
| struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> { | |
| struct ConvertOpToUnsigned final : OpRewritePattern<Signed> { |
also below
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not exactly sure where to leave this comment so will just do so here: the backing range analysis makes an assumption that IndexType is 64bits. While mostly harmless to the analysis, this pattern builds on that to apply an optimization that is not valid for a 32bit IndexType.
I don't have a good solution to this but it took me some sleuthing in an earlier incarnation to understand.
Commenting here because "staticallyNonNegative" is only applicable to 64bit IndexType. Probably should at least call for a comment somewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a comment.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we expect to have to iterate to convergence here? Otherwise can we set options to limit to a single iteration?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, it should finish in single iteration.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, adding maxIterations = 1 causing it to fail to converge. I think I will leave things as is for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably also need to change it to top down iteration in order to get the same convergence behavior as before.
I've definitely seen combined passes that are doing other optimizations along with unsigned conversions require multiple iterations to converge (and be more efficient with bottom up iteration), but I expect that this simple test pass just needs one top down pass through the IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It fails to converge (i.e. applyPatternsAndFoldGreedily returns failure) even with
config.maxIterations = 1;
config.useTopDownTraversal = true;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is mostly test pass anyway so I don't see much problem here. Downstream we plan to combine it with other patterns. And using greedy driver here may not be ideal, but it's still better than current dialect conversion driver.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is a test pass, it should be moved to the test folder and be named accordingly though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't intended as a test pass - you're meant to be able to run this (barring the usual philosophical disagreements about even having non-test passes upstream)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrote this code using the dialect converter because I wanted a one-shot "walk this function exactly once and apply matching patterns"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't intended as a test pass
Well, then my point stands: we shouldn't involve the greedy rewriter here.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |||||
| // CHECK: arith.cmpi slt | ||||||
| // CHECK: arith.cmpi sge | ||||||
| // CHECK: arith.cmpi sgt | ||||||
| func.func @not_with_maybe_overflow(%arg0 : i32) { | ||||||
| func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) { | ||||||
| %ci32_smax = arith.constant 0x7fffffff : i32 | ||||||
| %c1 = arith.constant 1 : i32 | ||||||
| %c4 = arith.constant 4 : i32 | ||||||
|
|
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) { | |||||
| %10 = arith.cmpi slt, %1, %c4 : i32 | ||||||
| %11 = arith.cmpi sge, %1, %c4 : i32 | ||||||
| %12 = arith.cmpi sgt, %1, %c4 : i32 | ||||||
| func.return | ||||||
| func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1 | ||||||
| } | ||||||
|
|
||||||
| // CHECK-LABEL: func @yes_with_no_overflow | ||||||
|
|
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) { | |||||
| // CHECK: arith.cmpi ult | ||||||
| // CHECK: arith.cmpi uge | ||||||
| // CHECK: arith.cmpi ugt | ||||||
| func.func @yes_with_no_overflow(%arg0 : i32) { | ||||||
| func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) { | ||||||
| %ci32_almost_smax = arith.constant 0x7ffffffe : i32 | ||||||
| %c1 = arith.constant 1 : i32 | ||||||
| %c4 = arith.constant 4 : i32 | ||||||
|
|
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) { | |||||
| %10 = arith.cmpi slt, %1, %c4 : i32 | ||||||
| %11 = arith.cmpi sge, %1, %c4 : i32 | ||||||
| %12 = arith.cmpi sgt, %1, %c4 : i32 | ||||||
| func.return | ||||||
| func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1 | ||||||
| } | ||||||
|
|
||||||
| // CHECK-LABEL: func @preserves_structure | ||||||
|
|
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) { | |||||
| func.func private @external() -> i8 | ||||||
|
|
||||||
| // CHECK-LABEL: @dead_code | ||||||
| func.func @dead_code() { | ||||||
| func.func @dead_code() -> i8 { | ||||||
| %0 = call @external() : () -> i8 | ||||||
| // CHECK: arith.floordivsi | ||||||
| %1 = arith.floordivsi %0, %0 : i8 | ||||||
| return | ||||||
| return %1 : i8 | ||||||
| } | ||||||
|
|
||||||
| // Make sure not crash. | ||||||
| // CHECK-LABEL: @no_integer_or_index | ||||||
| func.func @no_integer_or_index() { | ||||||
| func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> { | ||||||
|
||||||
| func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> { | |
| func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is true - all that
IntegerRangeAnalysisdoes is storeindexas 64-bit. The implementations for various ops onindexought to handle both cases - if they don't might be a bug