Skip to content

Commit c755012

Browse files
committed
Address review comments.
1 parent d1725fa commit c755012

File tree

2 files changed

+11
-30
lines changed

2 files changed

+11
-30
lines changed

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,39 +138,22 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
138138
}
139139
/// Canonicalize
140140
/// ` x = v + c1; y = x + c2` to `x = v + (c1 + c2)`
141-
/// ` x = v + c1; y = c2 + x` to `x = v + (c1 + c2)`
142-
/// ` x = c1 + v; y = x + c2` to `x = v + (c1 + c2)`
143-
/// ` x = c1 + v; y = c2 + x` to `x = v + (c1 + c2)`
144141
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
145-
146-
auto matchConstant = [](mlir::index::AddOp op, Value &v, IntegerAttr &c) {
147-
v = op.getLhs();
148-
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c))) {
149-
v = op.getRhs();
150-
if (!mlir::matchPattern(op.getLhs(), mlir::m_Constant(&c)))
151-
return false;
152-
}
153-
return true;
154-
};
155-
156142
IntegerAttr c1, c2;
157-
Value v1, v2;
143+
if (!mlir::matchPattern(op.getRhs(), mlir::m_Constant(&c1)))
144+
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
158145

159-
if (!matchConstant(op, v1, c1))
160-
return rewriter.notifyMatchFailure(op.getLoc(),
161-
"neither LHS nor RHS is constant");
162-
163-
auto add = v1.getDefiningOp<mlir::index::AddOp>();
146+
auto add = op.getLhs().getDefiningOp<mlir::index::AddOp>();
164147
if (!add)
165148
return rewriter.notifyMatchFailure(op.getLoc(), "LHS is not a add");
166149

167-
if (!matchConstant(add, v2, c2))
168-
return rewriter.notifyMatchFailure(op.getLoc(),
169-
"neither LHS nor RHS is constant");
150+
if (!mlir::matchPattern(add.getRhs(), mlir::m_Constant(&c2)))
151+
return rewriter.notifyMatchFailure(op.getLoc(), "RHS is not a constant");
170152

171153
auto c = rewriter.create<mlir::index::ConstantOp>(op->getLoc(),
172154
c1.getInt() + c2.getInt());
173-
auto newAdd = rewriter.create<mlir::index::AddOp>(op->getLoc(), v2, c);
155+
auto newAdd =
156+
rewriter.create<mlir::index::AddOp>(op->getLoc(), add.getLhs(), c);
174157

175158
rewriter.replaceOp(op, newAdd);
176159
return success();

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ func.func @add_fold_constants(%arg: index) -> (index) {
3737
%0 = index.constant 1
3838
%1 = index.constant 2
3939
%2 = index.add %arg, %0
40-
%3 = index.add %1, %2
41-
%4 = index.add %3, %1
42-
%5 = index.add %4, %0
40+
%3 = index.add %2, %1
4341

44-
// CHECK-DAG: [[A:%.*]] = index.constant 6
45-
// CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[A]]
42+
// CHECK-DAG: [[C3:%.*]] = index.constant 3
43+
// CHECK-DAG: [[V0:%.*]] = index.add %arg0, [[C3]]
4644
// CHECK: return [[V0]]
47-
return %5 : index
45+
return %3 : index
4846
}
4947

5048
// CHECK-LABEL: @sub

0 commit comments

Comments
 (0)