Skip to content

Commit 683abad

Browse files
support region op.
1 parent 397448a commit 683abad

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

mlir/lib/Transforms/HoistPureOps.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,34 @@ static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
5656
}
5757
}
5858

59+
static bool isOpContainBlock(Operation *op, Block *block) {
60+
Operation *parentOp = block->getParentOp();
61+
while (parentOp && parentOp != op) {
62+
parentOp = parentOp->getParentOp();
63+
}
64+
return parentOp == op ? true : false;
65+
}
66+
5967
/// Find the hoisting position for the pure op.
6068
static Value getDestPos(Operation *op) {
6169
DominanceInfo dominanceInfo(op);
6270
SmallVector<Value> operands(op->getOperands());
71+
if (op->getNumRegions()) {
72+
op->walk([&](Operation *operation) {
73+
for (auto operand : operation->getOperands()) {
74+
Operation *defineOp = operand.getDefiningOp();
75+
if (!defineOp) {
76+
BlockArgument argument = cast<BlockArgument>(operand);
77+
if (!isOpContainBlock(op, argument.getOwner()))
78+
operands.push_back(operand);
79+
continue;
80+
}
81+
if (!isOpContainBlock(op, defineOp->getBlock())) {
82+
operands.push_back(operand);
83+
}
84+
}
85+
});
86+
}
6387
if (operands.empty())
6488
return {};
6589
Value ret = operands[0];
@@ -71,13 +95,18 @@ static Value getDestPos(Operation *op) {
7195

7296
/// Hoist single pure op.
7397
static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
98+
LDBG() << "hoistPureOp: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
7499
Value pos = getDestPos(op);
75100
if (!pos)
76101
return;
77102

78103
if (Operation *defineOp = pos.getDefiningOp()) {
104+
if (op == defineOp)
105+
return;
106+
79107
LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
80-
<< " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
108+
<< " after "
109+
<< OpWithFlags(defineOp, OpPrintingFlags().skipRegions());
81110
rewriter.moveOpAfter(op, defineOp);
82111
return;
83112
}

mlir/test/Transforms/loop-invariant-code-motion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -split-input-file -loop-invariant-code-motion | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file -hoist-pure-ops | FileCheck %s
22

33
func.func @nested_loops_both_having_invariant_code() {
44
%m = memref.alloc() : memref<10xf32>

0 commit comments

Comments
 (0)