Skip to content

Commit 6e35949

Browse files
kumasentoivanradanov
authored andcommitted
[FoldSCFIf] added -fold-scf-if pass
1 parent 1830896 commit 6e35949

File tree

6 files changed

+152
-0
lines changed

6 files changed

+152
-0
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//===- FoldSCFIf.h - Fold scf.if into select --------------C++-===//
2+
3+
#ifndef POLYMER_TRANSFORMS_FOLDSCFIF_H
4+
#define POLYMER_TRANSFORMS_FOLDSCFIF_H
5+
6+
namespace polymer {
7+
void registerFoldSCFIfPass();
8+
}
9+
10+
#endif

tools/polymer/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_mlir_conversion_library(PolymerTransforms
55
ScopStmtOpt.cc
66
LoopAnnotate.cc
77
LoopExtract.cc
8+
FoldSCFIf.cc
89

910
ADDITIONAL_HEADER_DIRS
1011
"${POLYMER_MAIN_INCLUDE_DIR}/polymer/Transforms"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
//===- FoldSCFIf.cc - Fold scf.if into select --------------C++-===//
2+
3+
#include "polymer/Transforms/FoldSCFIf.h"
4+
5+
#include "mlir/Analysis/AffineAnalysis.h"
6+
#include "mlir/Analysis/AffineStructures.h"
7+
#include "mlir/Analysis/SliceAnalysis.h"
8+
#include "mlir/Analysis/Utils.h"
9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11+
#include "mlir/Dialect/SCF/SCF.h"
12+
#include "mlir/IR/BlockAndValueMapping.h"
13+
#include "mlir/IR/Builders.h"
14+
#include "mlir/IR/Dominance.h"
15+
#include "mlir/IR/OpImplementation.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/IR/Types.h"
18+
#include "mlir/IR/Value.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Pass/PassManager.h"
21+
#include "mlir/Transforms/DialectConversion.h"
22+
#include "mlir/Transforms/Passes.h"
23+
#include "mlir/Transforms/RegionUtils.h"
24+
#include "mlir/Transforms/Utils.h"
25+
26+
#include "llvm/Support/Debug.h"
27+
28+
using namespace mlir;
29+
using namespace llvm;
30+
31+
#define DEBUG_TYPE "fold-scf-if"
32+
33+
static void foldSCFIf(scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
34+
Location loc = ifOp.getLoc();
35+
36+
LLVM_DEBUG(dbgs() << "Working on ifOp: " << ifOp << "\n\n");
37+
38+
OpBuilder::InsertionGuard g(b);
39+
b.setInsertionPointAfter(ifOp);
40+
41+
SmallVector<Value> thenResults, elseResults;
42+
43+
auto cloneAfter = [&](Block *block, SmallVectorImpl<Value> &results) {
44+
BlockAndValueMapping vmap;
45+
for (Operation &op : block->getOperations()) {
46+
if (auto yieldOp = dyn_cast<scf::YieldOp>(op))
47+
for (Value result : yieldOp.getOperands())
48+
results.push_back(vmap.contains(result) ? vmap.lookup(result)
49+
: result);
50+
else
51+
b.clone(op, vmap);
52+
}
53+
};
54+
55+
cloneAfter(ifOp.thenBlock(), thenResults);
56+
cloneAfter(ifOp.elseBlock(), elseResults);
57+
58+
for (auto ifResult : enumerate(ifOp.getResults())) {
59+
Value newResult =
60+
b.create<SelectOp>(loc, ifOp.condition(), thenResults[ifResult.index()],
61+
elseResults[ifResult.index()]);
62+
ifResult.value().replaceAllUsesWith(newResult);
63+
}
64+
65+
ifOp.erase();
66+
}
67+
68+
/// Return true if anything changed.
69+
static bool process(FuncOp f, OpBuilder &b) {
70+
bool changed = false;
71+
72+
f.walk([&](scf::IfOp ifOp) {
73+
/// TODO: add verification.
74+
foldSCFIf(ifOp, f, b);
75+
changed = true;
76+
});
77+
78+
return changed;
79+
}
80+
81+
namespace {
82+
struct FoldSCFIfPass : PassWrapper<FoldSCFIfPass, OperationPass<FuncOp>> {
83+
void runOnOperation() override {
84+
FuncOp f = getOperation();
85+
OpBuilder b(f.getContext());
86+
87+
while (process(f, b))
88+
;
89+
}
90+
};
91+
} // namespace
92+
93+
void polymer::registerFoldSCFIfPass() {
94+
PassPipelineRegistration<>(
95+
"fold-scf-if", "Fold scf.if into select.",
96+
[](OpPassManager &pm) { pm.addPass(std::make_unique<FoldSCFIfPass>()); });
97+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: polymer-opt %s -fold-scf-if | FileCheck %s
2+
3+
func @foo(%a: f32, %b: f32, %c: i1, %d: i1) -> f32 {
4+
%0 = scf.if %c -> f32 {
5+
%0 = scf.if %d -> f32 {
6+
scf.yield %a : f32
7+
} else {
8+
scf.yield %b : f32
9+
}
10+
%1 = arith.addf %0, %b : f32
11+
scf.yield %1 : f32
12+
} else {
13+
%1 = arith.mulf %a, %b : f32
14+
scf.yield %1 : f32
15+
}
16+
return %0 : f32
17+
}
18+
19+
// CHECK: func @foo(%[[a:.*]]: f32, %[[b:.*]]: f32, %[[c:.*]]: i1, %[[d:.*]]: i1) -> f32
20+
// CHECK-NEXT: %[[v0:.*]] = select %[[d]], %[[a]], %[[b]] : f32
21+
// CHECK-NEXT: %[[v1:.*]] = arith.addf %[[v0]], %[[b]] : f32
22+
// CHECK-NEXT: %[[v2:.*]] = arith.mulf %[[a]], %[[b]] : f32
23+
// CHECK-NEXT: %[[v3:.*]] = select %[[c]], %[[v1]], %[[v2]] : f32
24+
// CHECK-NEXT: return %[[v3]] : f32
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: polymer-opt %s -fold-scf-if | FileCheck %s
2+
3+
func @foo(%a: f32, %b: f32, %c: i1) -> f32 {
4+
%0 = scf.if %c -> f32 {
5+
%1 = arith.addf %a, %b : f32
6+
scf.yield %1 : f32
7+
} else {
8+
%1 = arith.mulf %a, %b : f32
9+
scf.yield %1 : f32
10+
}
11+
return %0 : f32
12+
}
13+
14+
// CHECK: func @foo(%[[a:.*]]: f32, %[[b:.*]]: f32, %[[c:.*]]: i1) -> f32
15+
// CHECK-NEXT: %[[v0:.*]] = arith.addf %[[a]], %[[b]] : f32
16+
// CHECK-NEXT: %[[v1:.*]] = arith.mulf %[[a]], %[[b]] : f32
17+
// CHECK-NEXT: %[[v2:.*]] = select %[[c]], %[[v0]], %[[v1]] : f32
18+
// CHECK-NEXT: return %[[v2]] : f32

tools/polymer/tools/polymer-opt/polymer-opt.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//===----------------------------------------------------------------------===//
77

88
#include "polymer/Transforms/ExtractScopStmt.h"
9+
#include "polymer/Transforms/FoldSCFIf.h"
910
#include "polymer/Transforms/LoopAnnotate.h"
1011
#include "polymer/Transforms/LoopExtract.h"
1112
#include "polymer/Transforms/PlutoTransform.h"
@@ -59,6 +60,7 @@ int main(int argc, char *argv[]) {
5960
registerScopStmtOptPasses();
6061
registerLoopAnnotatePasses();
6162
registerLoopExtractPasses();
63+
registerFoldSCFIfPass();
6264

6365
// Register any pass manager command line options.
6466
registerMLIRContextCLOptions();

0 commit comments

Comments
 (0)