Skip to content

Commit ae17efd

Browse files
add baisc implement of hoist pure pass.
1 parent ca50a53 commit ae17efd

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

mlir/include/mlir/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class GreedyRewriteConfig;
3737
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
3838
#define GEN_PASS_DECL_CONTROLFLOWSINK
3939
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
40+
#define GEN_PASS_DECL_HOISTPUREOPS
4041
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
4142
#define GEN_PASS_DECL_INLINER
4243
#define GEN_PASS_DECL_MEM2REG

mlir/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms
2121
SymbolPrivatize.cpp
2222
TopologicalSort.cpp
2323
ViewOpGraph.cpp
24+
HoistPureOps.cpp
2425

2526
ADDITIONAL_HEADER_DIRS
2627
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the function of hoist the pure op based on SSA
10+
// dominance.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/IR/Dialect.h"
15+
#include "mlir/IR/Dominance.h"
16+
#include "mlir/IR/Operation.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Transforms/Passes.h"
19+
20+
namespace mlir {
21+
#define GEN_PASS_DEF_HOISTPUREOPS
22+
#include "mlir/Transforms/Passes.h.inc"
23+
} // namespace mlir
24+
25+
using namespace mlir;
26+
27+
namespace {
28+
29+
/// Return the dominated Value.
30+
static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
31+
Block *aB = a.getParentBlock();
32+
Block *bB = b.getParentBlock();
33+
if (isa_and_present<BlockArgument>(a) && isa_and_present<BlockArgument>(b)) {
34+
return dominanceInfo.dominates(aB, bB) ? b : a;
35+
} else if (isa_and_present<BlockArgument>(a) ||
36+
isa_and_present<BlockArgument>(b)) {
37+
if (aB == bB)
38+
return b;
39+
return dominanceInfo.dominates(aB, bB) ? b : a;
40+
} else {
41+
Operation *aDefineOp = a.getDefiningOp();
42+
Operation *bDefineOp = b.getDefiningOp();
43+
return dominanceInfo.dominates(aDefineOp, bDefineOp) ? b : a;
44+
}
45+
}
46+
47+
/// Find the hoisting position for the pure op.
48+
static Value getDestPos(Operation *op) {
49+
DominanceInfo dominanceInfo(op);
50+
SmallVector<Value> operands(op->getOperands());
51+
if (operands.empty())
52+
return {};
53+
Value ret = operands[0];
54+
for (int i = 1, e = operands.size(); i < e; ++i) {
55+
ret = getDomaincedValue(dominanceInfo, ret, operands[i]);
56+
}
57+
return ret;
58+
}
59+
60+
/// Hoist single pure op.
61+
static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
62+
Value pos = getDestPos(op);
63+
if (!pos)
64+
return;
65+
66+
if (Operation *defineOp = pos.getDefiningOp()) {
67+
rewriter.moveOpAfter(op, defineOp);
68+
return;
69+
}
70+
auto argument = cast<BlockArgument>(pos);
71+
rewriter.moveOpBefore(op, &argument.getOwner()->front());
72+
}
73+
74+
struct HoistPureOps : public impl::HoistPureOpsBase<HoistPureOps> {
75+
void runOnOperation() override;
76+
};
77+
} // namespace
78+
79+
void HoistPureOps::runOnOperation() {
80+
Operation *module = getOperation();
81+
IRRewriter rewriter(module->getContext());
82+
module->walk<WalkOrder::PreOrder>([&](Operation *op) {
83+
if (op->hasTrait<mlir::OpTrait::IsTerminator>())
84+
return;
85+
if (isPure(op)) {
86+
hoistPureOp(rewriter, op);
87+
}
88+
});
89+
}

0 commit comments

Comments
 (0)