1313#include " mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1414#include " mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1515#include " mlir/Dialect/Arith/IR/Arith.h"
16- #include " mlir/Transforms/DialectConversion.h"
16+ #include " mlir/IR/PatternMatch.h"
17+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
1718
1819namespace mlir {
1920namespace arith {
@@ -85,35 +86,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
8586}
8687
8788namespace {
89+ class DataFlowListener : public RewriterBase ::Listener {
90+ public:
91+ DataFlowListener (DataFlowSolver &s) : s(s) {}
92+
93+ protected:
94+ void notifyOperationErased (Operation *op) override {
95+ s.eraseState (s.getProgramPointAfter (op));
96+ for (Value res : op->getResults ())
97+ s.eraseState (res);
98+ }
99+
100+ DataFlowSolver &s;
101+ };
102+
88103template <typename Signed, typename Unsigned>
89- struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
90- using OpConversionPattern<Signed>::OpConversionPattern;
104+ struct ConvertOpToUnsigned final : public OpRewritePattern<Signed> {
105+ ConvertOpToUnsigned (MLIRContext *context, DataFlowSolver &s)
106+ : OpRewritePattern<Signed>(context), solver(s) {}
91107
92- LogicalResult matchAndRewrite (Signed op, typename Signed::Adaptor adaptor,
93- ConversionPatternRewriter &rw) const override {
94- rw.replaceOpWithNewOp <Unsigned>(op, op->getResultTypes (),
95- adaptor.getOperands (), op->getAttrs ());
108+ LogicalResult matchAndRewrite (Signed op, PatternRewriter &rw) const override {
109+ if (failed (
110+ staticallyNonNegative (this ->solver , static_cast <Operation *>(op))))
111+ return failure ();
112+
113+ rw.replaceOpWithNewOp <Unsigned>(op, op->getResultTypes (), op->getOperands (),
114+ op->getAttrs ());
96115 return success ();
97116 }
117+
118+ private:
119+ DataFlowSolver &solver;
98120};
99121
100- struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
101- using OpConversionPattern<CmpIOp>::OpConversionPattern;
122+ struct ConvertCmpIToUnsigned final : public OpRewritePattern<CmpIOp> {
123+ ConvertCmpIToUnsigned (MLIRContext *context, DataFlowSolver &s)
124+ : OpRewritePattern<CmpIOp>(context), solver(s) {}
125+
126+ LogicalResult matchAndRewrite (CmpIOp op, PatternRewriter &rw) const override {
127+ if (failed (isCmpIConvertable (this ->solver , op)))
128+ return failure ();
102129
103- LogicalResult matchAndRewrite (CmpIOp op, CmpIOpAdaptor adaptor,
104- ConversionPatternRewriter &rw) const override {
105130 rw.replaceOpWithNewOp <CmpIOp>(op, toUnsignedPred (op.getPredicate ()),
106131 op.getLhs (), op.getRhs ());
107132 return success ();
108133 }
134+
135+ private:
136+ DataFlowSolver &solver;
109137};
110138
111139struct ArithUnsignedWhenEquivalentPass
112140 : public arith::impl::ArithUnsignedWhenEquivalentBase<
113141 ArithUnsignedWhenEquivalentPass> {
114- // / Implementation structure: first find all equivalent ops and collect them,
115- // / then perform all the rewrites in a second pass over the target op. This
116- // / ensures that analysis results are not invalidated during rewriting.
142+
117143 void runOnOperation () override {
118144 Operation *op = getOperation ();
119145 MLIRContext *ctx = op->getContext ();
@@ -123,35 +149,32 @@ struct ArithUnsignedWhenEquivalentPass
123149 if (failed (solver.initializeAndRun (op)))
124150 return signalPassFailure ();
125151
126- ConversionTarget target (*ctx);
127- target.addLegalDialect <ArithDialect>();
128- target.addDynamicallyLegalOp <DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
129- MinSIOp, MaxSIOp, ExtSIOp>(
130- [&solver](Operation *op) -> std::optional<bool > {
131- return failed (staticallyNonNegative (solver, op));
132- });
133- target.addDynamicallyLegalOp <CmpIOp>(
134- [&solver](CmpIOp op) -> std::optional<bool > {
135- return failed (isCmpIConvertable (solver, op));
136- });
152+ DataFlowListener listener (solver);
137153
138154 RewritePatternSet patterns (ctx);
139- patterns.add <ConvertOpToUnsigned<DivSIOp, DivUIOp>,
140- ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
141- ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
142- ConvertOpToUnsigned<RemSIOp, RemUIOp>,
143- ConvertOpToUnsigned<MinSIOp, MinUIOp>,
144- ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
145- ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
146- ctx);
147-
148- if (failed (applyPartialConversion (op, target, std::move (patterns)))) {
155+ populateUnsignedWhenEquivalentPatterns (patterns, solver);
156+
157+ GreedyRewriteConfig config;
158+ config.listener = &listener;
159+
160+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns), config)))
149161 signalPassFailure ();
150- }
151162 }
152163};
153164} // end anonymous namespace
154165
166+ void mlir::arith::populateUnsignedWhenEquivalentPatterns (
167+ RewritePatternSet &patterns, DataFlowSolver &solver) {
168+ patterns.add <ConvertOpToUnsigned<DivSIOp, DivUIOp>,
169+ ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
170+ ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
171+ ConvertOpToUnsigned<RemSIOp, RemUIOp>,
172+ ConvertOpToUnsigned<MinSIOp, MinUIOp>,
173+ ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
174+ ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
175+ patterns.getContext (), solver);
176+ }
177+
155178std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass () {
156179 return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157180}
0 commit comments