1313#include " mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
1414#include " mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
1515#include " mlir/Dialect/Arith/IR/Arith.h"
16- #include " mlir/Transforms/DialectConversion.h"
16+ #include " mlir/IR/PatternMatch.h"
17+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
1718
1819namespace mlir {
1920namespace arith {
@@ -29,6 +30,9 @@ using namespace mlir::dataflow;
2930// / Succeeds when a value is statically non-negative in that it has a lower
3031// / bound on its value (if it is treated as signed) and that bound is
3132// / non-negative.
33+ // TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
34+ // relies on this. These transformations may not be valid for 32bit index,
35+ // need more investigation.
3236static LogicalResult staticallyNonNegative (DataFlowSolver &solver, Value v) {
3337 auto *result = solver.lookupState <IntegerValueRangeLattice>(v);
3438 if (!result || result->getValue ().isUninitialized ())
@@ -85,35 +89,60 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
8589}
8690
8791namespace {
92+ class DataFlowListener : public RewriterBase ::Listener {
93+ public:
94+ DataFlowListener (DataFlowSolver &s) : s(s) {}
95+
96+ protected:
97+ void notifyOperationErased (Operation *op) override {
98+ s.eraseState (s.getProgramPointAfter (op));
99+ for (Value res : op->getResults ())
100+ s.eraseState (res);
101+ }
102+
103+ DataFlowSolver &s;
104+ };
105+
88106template <typename Signed, typename Unsigned>
89- struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
90- using OpConversionPattern<Signed>::OpConversionPattern;
107+ struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
108+ ConvertOpToUnsigned (MLIRContext *context, DataFlowSolver &s)
109+ : OpRewritePattern<Signed>(context), solver(s) {}
91110
92- LogicalResult matchAndRewrite (Signed op, typename Signed::Adaptor adaptor,
93- ConversionPatternRewriter &rw) const override {
94- rw.replaceOpWithNewOp <Unsigned>(op, op->getResultTypes (),
95- adaptor.getOperands (), op->getAttrs ());
111+ LogicalResult matchAndRewrite (Signed op, PatternRewriter &rw) const override {
112+ if (failed (
113+ staticallyNonNegative (this ->solver , static_cast <Operation *>(op))))
114+ return failure ();
115+
116+ rw.replaceOpWithNewOp <Unsigned>(op, op->getResultTypes (), op->getOperands (),
117+ op->getAttrs ());
96118 return success ();
97119 }
120+
121+ private:
122+ DataFlowSolver &solver;
98123};
99124
100- struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
101- using OpConversionPattern<CmpIOp>::OpConversionPattern;
125+ struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
126+ ConvertCmpIToUnsigned (MLIRContext *context, DataFlowSolver &s)
127+ : OpRewritePattern<CmpIOp>(context), solver(s) {}
128+
129+ LogicalResult matchAndRewrite (CmpIOp op, PatternRewriter &rw) const override {
130+ if (failed (isCmpIConvertable (this ->solver , op)))
131+ return failure ();
102132
103- LogicalResult matchAndRewrite (CmpIOp op, CmpIOpAdaptor adaptor,
104- ConversionPatternRewriter &rw) const override {
105133 rw.replaceOpWithNewOp <CmpIOp>(op, toUnsignedPred (op.getPredicate ()),
106134 op.getLhs (), op.getRhs ());
107135 return success ();
108136 }
137+
138+ private:
139+ DataFlowSolver &solver;
109140};
110141
111142struct ArithUnsignedWhenEquivalentPass
112143 : public arith::impl::ArithUnsignedWhenEquivalentBase<
113144 ArithUnsignedWhenEquivalentPass> {
114- // / Implementation structure: first find all equivalent ops and collect them,
115- // / then perform all the rewrites in a second pass over the target op. This
116- // / ensures that analysis results are not invalidated during rewriting.
145+
117146 void runOnOperation () override {
118147 Operation *op = getOperation ();
119148 MLIRContext *ctx = op->getContext ();
@@ -123,35 +152,32 @@ struct ArithUnsignedWhenEquivalentPass
123152 if (failed (solver.initializeAndRun (op)))
124153 return signalPassFailure ();
125154
126- ConversionTarget target (*ctx);
127- target.addLegalDialect <ArithDialect>();
128- target.addDynamicallyLegalOp <DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
129- MinSIOp, MaxSIOp, ExtSIOp>(
130- [&solver](Operation *op) -> std::optional<bool > {
131- return failed (staticallyNonNegative (solver, op));
132- });
133- target.addDynamicallyLegalOp <CmpIOp>(
134- [&solver](CmpIOp op) -> std::optional<bool > {
135- return failed (isCmpIConvertable (solver, op));
136- });
155+ DataFlowListener listener (solver);
137156
138157 RewritePatternSet patterns (ctx);
139- patterns.add <ConvertOpToUnsigned<DivSIOp, DivUIOp>,
140- ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
141- ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
142- ConvertOpToUnsigned<RemSIOp, RemUIOp>,
143- ConvertOpToUnsigned<MinSIOp, MinUIOp>,
144- ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
145- ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
146- ctx);
147-
148- if (failed (applyPartialConversion (op, target, std::move (patterns)))) {
158+ populateUnsignedWhenEquivalentPatterns (patterns, solver);
159+
160+ GreedyRewriteConfig config;
161+ config.listener = &listener;
162+
163+ if (failed (applyPatternsAndFoldGreedily (op, std::move (patterns), config)))
149164 signalPassFailure ();
150- }
151165 }
152166};
153167} // end anonymous namespace
154168
169+ void mlir::arith::populateUnsignedWhenEquivalentPatterns (
170+ RewritePatternSet &patterns, DataFlowSolver &solver) {
171+ patterns.add <ConvertOpToUnsigned<DivSIOp, DivUIOp>,
172+ ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
173+ ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
174+ ConvertOpToUnsigned<RemSIOp, RemUIOp>,
175+ ConvertOpToUnsigned<MinSIOp, MinUIOp>,
176+ ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
177+ ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
178+ patterns.getContext (), solver);
179+ }
180+
155181std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass () {
156182 return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157183}
0 commit comments