@@ -138,39 +138,22 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
138138}
139139// / Canonicalize
140140// / ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
141- // / ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
142- // / ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
143- // / ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
144141LogicalResult AddOp::canonicalize (AddOp op, PatternRewriter &rewriter) {
145-
146- auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
147- v = op.getLhs ();
148- if (!mlir::matchPattern (op.getRhs (), mlir::m_Constant (&c))) {
149- v = op.getRhs ();
150- if (!mlir::matchPattern (op.getLhs (), mlir::m_Constant (&c)))
151- return false ;
152- }
153- return true ;
154- };
155-
156142 IntegerAttr c1, c2;
157- Value v1, v2;
143+ if (!mlir::matchPattern (op.getRhs (), mlir::m_Constant (&c1)))
144+ return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
158145
159- if (!matchConstant (op, v1, c1))
160- return rewriter.notifyMatchFailure (op.getLoc (),
161- " neither LHS nor RHS is constant" );
162-
163- auto add = v1.getDefiningOp <mlir::index::AddOp>();
146+ auto add = op.getLhs ().getDefiningOp <mlir::index::AddOp>();
164147 if (!add)
165148 return rewriter.notifyMatchFailure (op.getLoc (), " LHS is not a add" );
166149
167- if (!matchConstant (add, v2, c2))
168- return rewriter.notifyMatchFailure (op.getLoc (),
169- " neither LHS nor RHS is constant" );
150+ if (!mlir::matchPattern (add.getRhs (), mlir::m_Constant (&c2)))
151+ return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
170152
171153 auto c = rewriter.create <mlir::index::ConstantOp>(op->getLoc (),
172154 c1.getInt () + c2.getInt ());
173- auto newAdd = rewriter.create <mlir::index::AddOp>(op->getLoc (), v2, c);
155+ auto newAdd =
156+ rewriter.create <mlir::index::AddOp>(op->getLoc (), add.getLhs (), c);
174157
175158 rewriter.replaceOp (op, newAdd);
176159 return success ();
0 commit comments