-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] UnsignedWhenEquivalent: use greedy rewriter instead of dialect conversion #112454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,8 @@ | |
| #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" | ||
| #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" | ||
| #include "mlir/Dialect/Arith/IR/Arith.h" | ||
| #include "mlir/Transforms/DialectConversion.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| namespace mlir { | ||
| namespace arith { | ||
|
|
@@ -85,35 +86,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { | |
| } | ||
|
|
||
| namespace { | ||
| class DataFlowListener : public RewriterBase::Listener { | ||
| public: | ||
| DataFlowListener(DataFlowSolver &s) : s(s) {} | ||
|
|
||
| protected: | ||
| void notifyOperationErased(Operation *op) override { | ||
| s.eraseState(s.getProgramPointAfter(op)); | ||
| for (Value res : op->getResults()) | ||
| s.eraseState(res); | ||
| } | ||
|
|
||
| DataFlowSolver &s; | ||
| }; | ||
|
|
||
| template <typename Signed, typename Unsigned> | ||
| struct ConvertOpToUnsigned : OpConversionPattern<Signed> { | ||
| using OpConversionPattern<Signed>::OpConversionPattern; | ||
| struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> { | ||
| ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s) | ||
| : OpRewritePattern<Signed>(context), solver(s) {} | ||
|
|
||
| LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor, | ||
| ConversionPatternRewriter &rw) const override { | ||
| rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), | ||
| adaptor.getOperands(), op->getAttrs()); | ||
| LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override { | ||
| if (failed( | ||
| staticallyNonNegative(this->solver, static_cast<Operation *>(op)))) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not exactly sure where to leave this comment, so I will just do so here: the backing range analysis assumes that IndexType is 64 bits wide. While mostly harmless to the analysis itself, this pattern builds on that assumption to apply an optimization that is not valid for a 32-bit IndexType. I don't have a good solution to this, but it took me some sleuthing in an earlier incarnation to understand. I'm commenting here because `staticallyNonNegative` is only applicable to a 64-bit IndexType. This should probably at least call for a comment somewhere.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added a comment.
||
| return failure(); | ||
|
|
||
| rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(), | ||
| op->getAttrs()); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| DataFlowSolver &solver; | ||
| }; | ||
|
|
||
| struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> { | ||
| using OpConversionPattern<CmpIOp>::OpConversionPattern; | ||
| struct ConvertCmpIToUnsigned final : public OpRewritePattern<CmpIOp> { | ||
| ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s) | ||
| : OpRewritePattern<CmpIOp>(context), solver(s) {} | ||
|
|
||
| LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override { | ||
| if (failed(isCmpIConvertable(this->solver, op))) | ||
| return failure(); | ||
|
|
||
| LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor, | ||
| ConversionPatternRewriter &rw) const override { | ||
| rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()), | ||
| op.getLhs(), op.getRhs()); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| DataFlowSolver &solver; | ||
| }; | ||
|
|
||
| struct ArithUnsignedWhenEquivalentPass | ||
| : public arith::impl::ArithUnsignedWhenEquivalentBase< | ||
| ArithUnsignedWhenEquivalentPass> { | ||
| /// Implementation structure: first find all equivalent ops and collect them, | ||
| /// then perform all the rewrites in a second pass over the target op. This | ||
| /// ensures that analysis results are not invalidated during rewriting. | ||
|
|
||
| void runOnOperation() override { | ||
| Operation *op = getOperation(); | ||
| MLIRContext *ctx = op->getContext(); | ||
|
|
@@ -123,35 +149,32 @@ struct ArithUnsignedWhenEquivalentPass | |
| if (failed(solver.initializeAndRun(op))) | ||
| return signalPassFailure(); | ||
|
|
||
| ConversionTarget target(*ctx); | ||
| target.addLegalDialect<ArithDialect>(); | ||
| target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, | ||
| MinSIOp, MaxSIOp, ExtSIOp>( | ||
| [&solver](Operation *op) -> std::optional<bool> { | ||
| return failed(staticallyNonNegative(solver, op)); | ||
| }); | ||
| target.addDynamicallyLegalOp<CmpIOp>( | ||
| [&solver](CmpIOp op) -> std::optional<bool> { | ||
| return failed(isCmpIConvertable(solver, op)); | ||
| }); | ||
| DataFlowListener listener(solver); | ||
|
|
||
| RewritePatternSet patterns(ctx); | ||
| patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>, | ||
| ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>, | ||
| ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>, | ||
| ConvertOpToUnsigned<RemSIOp, RemUIOp>, | ||
| ConvertOpToUnsigned<MinSIOp, MinUIOp>, | ||
| ConvertOpToUnsigned<MaxSIOp, MaxUIOp>, | ||
| ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>( | ||
| ctx); | ||
|
|
||
| if (failed(applyPartialConversion(op, target, std::move(patterns)))) { | ||
| populateUnsignedWhenEquivalentPatterns(patterns, solver); | ||
|
|
||
| GreedyRewriteConfig config; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we expect to have to iterate to convergence here? Otherwise can we set options to limit to a single iteration?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it should finish in a single iteration.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, adding [the rest of this comment, including its inline code, was lost in extraction]
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You probably also need to change it to top down iteration in order to get the same convergence behavior as before. I've definitely seen combined passes that are doing other optimizations along with unsigned conversions require multiple iterations to converge (and be more efficient with bottom up iteration), but I expect that this simple test pass just needs one top down pass through the IR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It fails to converge (i.e. [the rest of this comment, including its inline code, was lost in extraction])
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is mostly a test pass anyway, so I don't see much of a problem here. Downstream, we plan to combine it with other patterns. Using the greedy driver here may not be ideal, but it's still better than the current dialect conversion driver.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is a test pass, it should be moved to the test folder and be named accordingly though.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This isn't intended as a test pass - you're meant to be able to run this (barring the usual philosophical disagreements about even having non-test passes upstream)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wrote this code using the dialect converter because I wanted a one-shot "walk this function exactly once and apply matching patterns"
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Well, then my point stands: we shouldn't involve the greedy rewriter here.
||
| config.listener = &listener; | ||
|
|
||
| if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) | ||
| signalPassFailure(); | ||
| } | ||
| } | ||
| }; | ||
| } // end anonymous namespace | ||
|
|
||
| void mlir::arith::populateUnsignedWhenEquivalentPatterns( | ||
| RewritePatternSet &patterns, DataFlowSolver &solver) { | ||
| patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>, | ||
| ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>, | ||
| ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>, | ||
| ConvertOpToUnsigned<RemSIOp, RemUIOp>, | ||
| ConvertOpToUnsigned<MinSIOp, MinUIOp>, | ||
| ConvertOpToUnsigned<MaxSIOp, MaxUIOp>, | ||
| ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>( | ||
| patterns.getContext(), solver); | ||
| } | ||
|
|
||
| std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() { | ||
| return std::make_unique<ArithUnsignedWhenEquivalentPass>(); | ||
| } | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,7 @@ | |||||
| // CHECK: arith.cmpi slt | ||||||
| // CHECK: arith.cmpi sge | ||||||
| // CHECK: arith.cmpi sgt | ||||||
| func.func @not_with_maybe_overflow(%arg0 : i32) { | ||||||
| func.func @not_with_maybe_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) { | ||||||
| %ci32_smax = arith.constant 0x7fffffff : i32 | ||||||
| %c1 = arith.constant 1 : i32 | ||||||
| %c4 = arith.constant 4 : i32 | ||||||
|
|
@@ -29,7 +29,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) { | |||||
| %10 = arith.cmpi slt, %1, %c4 : i32 | ||||||
| %11 = arith.cmpi sge, %1, %c4 : i32 | ||||||
| %12 = arith.cmpi sgt, %1, %c4 : i32 | ||||||
| func.return | ||||||
| func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1 | ||||||
| } | ||||||
|
|
||||||
| // CHECK-LABEL: func @yes_with_no_overflow | ||||||
|
|
@@ -44,7 +44,7 @@ func.func @not_with_maybe_overflow(%arg0 : i32) { | |||||
| // CHECK: arith.cmpi ult | ||||||
| // CHECK: arith.cmpi uge | ||||||
| // CHECK: arith.cmpi ugt | ||||||
| func.func @yes_with_no_overflow(%arg0 : i32) { | ||||||
| func.func @yes_with_no_overflow(%arg0 : i32) -> (i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1) { | ||||||
| %ci32_almost_smax = arith.constant 0x7ffffffe : i32 | ||||||
| %c1 = arith.constant 1 : i32 | ||||||
| %c4 = arith.constant 4 : i32 | ||||||
|
|
@@ -61,7 +61,7 @@ func.func @yes_with_no_overflow(%arg0 : i32) { | |||||
| %10 = arith.cmpi slt, %1, %c4 : i32 | ||||||
| %11 = arith.cmpi sge, %1, %c4 : i32 | ||||||
| %12 = arith.cmpi sgt, %1, %c4 : i32 | ||||||
| func.return | ||||||
| func.return %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : i32, i32, i32, i32, i32, i32, i64, i1, i1, i1, i1 | ||||||
| } | ||||||
|
|
||||||
| // CHECK-LABEL: func @preserves_structure | ||||||
|
|
@@ -90,20 +90,20 @@ func.func @preserves_structure(%arg0 : memref<8xindex>) { | |||||
| func.func private @external() -> i8 | ||||||
|
|
||||||
| // CHECK-LABEL: @dead_code | ||||||
| func.func @dead_code() { | ||||||
| func.func @dead_code() -> i8 { | ||||||
| %0 = call @external() : () -> i8 | ||||||
| // CHECK: arith.floordivsi | ||||||
| %1 = arith.floordivsi %0, %0 : i8 | ||||||
| return | ||||||
| return %1 : i8 | ||||||
| } | ||||||
|
|
||||||
| // Make sure not crash. | ||||||
| // CHECK-LABEL: @no_integer_or_index | ||||||
| func.func @no_integer_or_index() { | ||||||
| func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> { | ||||||
|
||||||
| func.func @no_integer_or_index(%arg0 : vector<1xi32> ) -> vector<1xi1> { | |
| func.func @no_integer_or_index(%arg0: vector<1xi32>) -> vector<1xi1> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove the stray spaces in the signature, as in the suggested line above;
the same applies below.