@@ -118,6 +118,32 @@ static OpFoldResult foldBinaryOpChecked(
118118 return IntegerAttr::get (IndexType::get (lhs.getContext ()), *result64);
119119}
120120
121+ // / Helper for associative and commutative binary ops that can be transformed:
122+ // / `x = op(v, c1); y = op(x, c2)` -> `tmp = op(c1, c2); y = op(v, tmp)`
123+ // / where c1 and c2 are constants. It is expected that `tmp` will be folded.
124+ template <typename BinaryOp>
125+ LogicalResult
126+ canonicalizeAssociativeCommutativeBinaryOp (BinaryOp op,
127+ PatternRewriter &rewriter) {
128+ if (!mlir::matchPattern (op.getRhs (), mlir::m_Constant ()))
129+ return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
130+
131+ auto lhsOp = op.getLhs ().template getDefiningOp <BinaryOp>();
132+ if (!lhsOp)
133+ return rewriter.notifyMatchFailure (op.getLoc (), " LHS is not the same BinaryOp" );
134+
135+ if (!mlir::matchPattern (lhsOp.getRhs (), mlir::m_Constant ()))
136+ return rewriter.notifyMatchFailure (op.getLoc (), " RHS of LHS op is not a constant" );
137+
138+ Value c = rewriter.createOrFold <BinaryOp>(op->getLoc (), op.getRhs (),
139+ lhsOp.getRhs ());
140+ if (c.getDefiningOp <BinaryOp>())
141+ return rewriter.notifyMatchFailure (op.getLoc (), " new BinaryOp was not folded" );
142+
143+ rewriter.replaceOpWithNewOp <BinaryOp>(op, lhsOp.getLhs (), c);
144+ return success ();
145+ }
146+
121147// ===----------------------------------------------------------------------===//
122148// AddOp
123149// ===----------------------------------------------------------------------===//
@@ -136,27 +162,9 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
136162
137163 return {};
138164}
139- // / Canonicalize
140- // / ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
141- LogicalResult AddOp::canonicalize (AddOp op, PatternRewriter &rewriter) {
142- IntegerAttr c1, c2;
143- if (!mlir::matchPattern (op.getRhs (), mlir::m_Constant (&c1)))
144- return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
145-
146- auto add = op.getLhs ().getDefiningOp <mlir::index::AddOp>();
147- if (!add)
148- return rewriter.notifyMatchFailure (op.getLoc (), " LHS is not a add" );
149-
150- if (!mlir::matchPattern (add.getRhs (), mlir::m_Constant (&c2)))
151- return rewriter.notifyMatchFailure (op.getLoc (), " RHS is not a constant" );
152-
153- auto c = rewriter.create <mlir::index::ConstantOp>(op->getLoc (),
154- c1.getInt () + c2.getInt ());
155- auto newAdd =
156- rewriter.create <mlir::index::AddOp>(op->getLoc (), add.getLhs (), c);
157165
158- rewriter. replaceOp ( op, newAdd);
159- return success ( );
166+ LogicalResult AddOp::canonicalize (AddOp op, PatternRewriter &rewriter) {
167+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter );
160168}
161169
162170// ===----------------------------------------------------------------------===//
@@ -200,6 +208,10 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
200208 return {};
201209}
202210
211+ LogicalResult MulOp::canonicalize (MulOp op, PatternRewriter &rewriter) {
212+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
213+ }
214+
203215// ===----------------------------------------------------------------------===//
204216// DivSOp
205217// ===----------------------------------------------------------------------===//
@@ -352,6 +364,10 @@ OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
352364 });
353365}
354366
367+ LogicalResult MaxSOp::canonicalize (MaxSOp op, PatternRewriter &rewriter) {
368+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
369+ }
370+
355371// ===----------------------------------------------------------------------===//
356372// MaxUOp
357373// ===----------------------------------------------------------------------===//
@@ -363,6 +379,10 @@ OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
363379 });
364380}
365381
382+ LogicalResult MaxUOp::canonicalize (MaxUOp op, PatternRewriter &rewriter) {
383+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
384+ }
385+
366386// ===----------------------------------------------------------------------===//
367387// MinSOp
368388// ===----------------------------------------------------------------------===//
@@ -374,6 +394,10 @@ OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
374394 });
375395}
376396
397+ LogicalResult MinSOp::canonicalize (MinSOp op, PatternRewriter &rewriter) {
398+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
399+ }
400+
377401// ===----------------------------------------------------------------------===//
378402// MinUOp
379403// ===----------------------------------------------------------------------===//
@@ -385,6 +409,10 @@ OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
385409 });
386410}
387411
412+ LogicalResult MinUOp::canonicalize (MinUOp op, PatternRewriter &rewriter) {
413+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
414+ }
415+
388416// ===----------------------------------------------------------------------===//
389417// ShlOp
390418// ===----------------------------------------------------------------------===//
@@ -442,6 +470,10 @@ OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
442470 [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
443471}
444472
473+ LogicalResult AndOp::canonicalize (AndOp op, PatternRewriter &rewriter) {
474+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
475+ }
476+
445477// ===----------------------------------------------------------------------===//
446478// OrOp
447479// ===----------------------------------------------------------------------===//
@@ -452,6 +484,10 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
452484 [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
453485}
454486
487+ LogicalResult OrOp::canonicalize (OrOp op, PatternRewriter &rewriter) {
488+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
489+ }
490+
455491// ===----------------------------------------------------------------------===//
456492// XOrOp
457493// ===----------------------------------------------------------------------===//
@@ -462,6 +498,10 @@ OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
462498 [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
463499}
464500
501+ LogicalResult XOrOp::canonicalize (XOrOp op, PatternRewriter &rewriter) {
502+ return canonicalizeAssociativeCommutativeBinaryOp (op, rewriter);
503+ }
504+
465505// ===----------------------------------------------------------------------===//
466506// CastSOp
467507// ===----------------------------------------------------------------------===//
0 commit comments