Skip to content

Commit b00482f

Browse files
kumasento authored and ivanradanov committed
[FoldSCFIf] scalrep before folding
[FoldSCFIf] scalrep (affine scalar replacement) before folding
1 parent a110dd3 commit b00482f

File tree

5 files changed

+28
-65
lines changed

5 files changed

+28
-65
lines changed

tools/polymer/lib/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_mlir_conversion_library(PolymerTransforms
2424
MLIRStandard
2525
MLIRSupport
2626
MLIRAffineToStandard
27+
MLIRAffineTransforms
2728

2829
PolymerSupport
2930
PolymerTargetOpenScop

tools/polymer/lib/Transforms/FoldSCFIf.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "mlir/Analysis/Utils.h"
99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1010
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11+
#include "mlir/Dialect/Affine/Passes.h"
1112
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1213
#include "mlir/Dialect/SCF/SCF.h"
1314
#include "mlir/IR/BlockAndValueMapping.h"
@@ -125,6 +126,8 @@ struct MatchIfElsePass : PassWrapper<MatchIfElsePass, OperationPass<FuncOp>> {
125126

126127
LLVM_DEBUG(dbgs() << "Matched else block:\n" << ifOp << "\n\n");
127128
});
129+
130+
LLVM_DEBUG(dbgs() << "After store matched:\n" << f << "\n\n");
128131
}
129132
};
130133
} // namespace
@@ -290,20 +293,22 @@ struct LiftStoreOps : PassWrapper<LiftStoreOps, OperationPass<FuncOp>> {
290293
// branch.
291294
while (processLiftStoreOps(f, b))
292295
;
296+
297+
LLVM_DEBUG(dbgs() << "After LiftStoreOps: " << f << "\n\n");
293298
}
294299
};
295300
} // namespace
296301

297302
/// ---------------------- FoldSCFIf ----------------------------------
298303

299-
static void foldSCFIf(scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
304+
static bool foldSCFIf(scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
300305
Location loc = ifOp.getLoc();
301306

302307
LLVM_DEBUG(dbgs() << "Working on ifOp: " << ifOp << "\n\n");
303308

304309
if (!hasSingleStore(ifOp.thenBlock()) ||
305310
(ifOp.elseBlock() && !hasSingleStore(ifOp.elseBlock())))
306-
return;
311+
return false;
307312

308313
OpBuilder::InsertionGuard g(b);
309314
b.setInsertionPointAfter(ifOp);
@@ -337,6 +342,7 @@ static void foldSCFIf(scf::IfOp ifOp, FuncOp f, OpBuilder &b) {
337342
}
338343

339344
ifOp.erase();
345+
return true;
340346
}
341347

342348
/// Return true if anything changed.
@@ -345,8 +351,11 @@ static bool process(FuncOp f, OpBuilder &b) {
345351

346352
f.walk([&](scf::IfOp ifOp) {
347353
/// TODO: add verification.
348-
foldSCFIf(ifOp, f, b);
349-
changed = true;
354+
if (changed)
355+
return;
356+
357+
changed = foldSCFIf(ifOp, f, b);
358+
;
350359
});
351360

352361
return changed;
@@ -371,6 +380,7 @@ void polymer::registerFoldSCFIfPass() {
371380
[](OpPassManager &pm) {
372381
pm.addPass(std::make_unique<MatchIfElsePass>());
373382
pm.addPass(std::make_unique<LiftStoreOps>());
383+
pm.addPass(createAffineScalarReplacementPass());
374384
pm.addPass(std::make_unique<FoldSCFIfPass>());
375385
});
376386
}

tools/polymer/test/polymer-opt/Application/aes.mlir

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -51,64 +51,13 @@ func @encrypt(%Sbox: memref<?x16xi32>, %statemt: memref<?xi32>) {
5151
return
5252
}
5353

54-
// CHECK: func private @S0(%[[statemt:.*]]: memref<?xi32>, %[[i:.*]]: index, %[[Sbox:.*]]: memref<?x16xi32>) attributes {scop.stmt}
55-
// CHECK-NEXT: %[[c15_i32:.*]] = arith.constant 15 : i32
56-
// CHECK-NEXT: %[[c4_i32:.*]] = arith.constant 4 : i32
57-
// CHECK-NEXT: %[[v0:.*]] = affine.load %[[statemt]][symbol(%[[i]]) * 4] : memref<?xi32>
58-
// CHECK-NEXT: %[[v1:.*]] = arith.shrsi %[[v0]], %[[c4_i32]] : i32
59-
// CHECK-NEXT: %[[v2:.*]] = arith.index_cast %[[v1]] : i32 to index
60-
// CHECK-NEXT: %[[v3:.*]] = arith.andi %[[v0]], %[[c15_i32]] : i32
61-
// CHECK-NEXT: %[[v4:.*]] = arith.index_cast %[[v3]] : i32 to index
62-
// CHECK-NEXT: %[[v5:.*]] = memref.load %[[Sbox]][%[[v2]], %[[v4]]] : memref<?x16xi32>
63-
// CHECK-NEXT: affine.store %[[v5]], %[[statemt]][symbol(%[[i]]) * 4] : memref<?xi32>
64-
65-
// CHECK: func private @S1(%[[i:.*]]: index, %[[ret:.*]]: memref<1024xi32>, %[[statemt:.*]]: memref<?xi32>) attributes {scop.stmt}
66-
// CHECK-NEXT: %[[c1_i32:.*]] = arith.constant 1 : i32
67-
// CHECK-NEXT: %[[v0:.*]] = affine.load %[[statemt]][symbol(%[[i]])] : memref<?xi32>
68-
// CHECK-NEXT: %[[v1:.*]] = arith.shli %[[v0]], %[[c1_i32]] : i32
69-
// CHECK-NEXT: affine.store %[[v1]], %[[ret]][symbol(%[[i]])] : memref<1024xi32>
70-
71-
// CHECK: func private @S2(%[[i:.*]]: index, %[[ret:.*]]: memref<1024xi32>) attributes {scop.stmt}
72-
// CHECK: %[[c283_i32:.*]] = arith.constant 283 : i32
73-
// CHECK: %[[c1_i32:.*]] = arith.constant 1 : i32
74-
// CHECK: %[[c8_i32:.*]] = arith.constant 8 : i32
75-
// CHECK: %[[v0:.*]] = affine.load %[[ret]][symbol(%[[i]])] : memref<1024xi32>
76-
// CHECK: %[[v1:.*]] = arith.shrsi %[[v0]], %[[c8_i32]] : i32
77-
// CHECK: %[[v2:.*]] = arith.cmpi eq, %[[v1]], %[[c1_i32]] : i32
78-
// CHECK: %[[v3:.*]] = arith.xori %[[v0]], %[[c283_i32]] : i32
79-
// CHECK: %[[v4:.*]] = affine.load %[[ret]][symbol(%[[i]])] : memref<1024xi32>
80-
// CHECK: %[[v5:.*]] = select %[[v2]], %[[v3]], %[[v4]] : i32
81-
// CHECK: affine.store %[[v5]], %[[ret]][symbol(%[[i]])] : memref<1024xi32>
82-
83-
// CHECK: func private @S3(%[[i:.*]]: index, %[[ret:.*]]: memref<1024xi32>, %[[statemt]]: memref<?xi32>) attributes {scop.stmt}
84-
// CHECK-NEXT: %[[c283_i32:.*]] = arith.constant 283 : i32
85-
// CHECK-NEXT: %[[c1_i32:.*]] = arith.constant 1 : i32
86-
// CHECK-NEXT: %[[c8_i32:.*]] = arith.constant 8 : i32
87-
// CHECK-NEXT: %[[v0:.*]] = affine.load %[[ret]][symbol(%[[i]])] : memref<1024xi32>
88-
// CHECK-NEXT: %[[v1:.*]] = affine.load %[[ret]][symbol(%[[i]])] : memref<1024xi32>
89-
// CHECK-NEXT: %[[v2:.*]] = affine.load %[[statemt]][symbol(%[[i]]) + 1] : memref<?xi32>
90-
// CHECK-NEXT: %[[v3:.*]] = arith.shli %[[v2]], %[[c1_i32]] : i32
91-
// CHECK-NEXT: %[[v4:.*]] = arith.xori %[[v2]], %[[v3]] : i32
92-
// CHECK-NEXT: %[[v5:.*]] = arith.shrsi %[[v4]], %[[c8_i32]] : i32
93-
// CHECK-NEXT: %[[v6:.*]] = arith.cmpi eq, %[[v5]], %[[c1_i32]] : i32
94-
// CHECK-NEXT: %[[v7:.*]] = arith.xori %[[v4]], %[[c283_i32]] : i32
95-
// CHECK-NEXT: %[[v8:.*]] = arith.xori %[[v0]], %[[v7]] : i32
96-
// CHECK-NEXT: %[[v9:.*]] = arith.xori %[[v1]], %[[v4]] : i32
97-
// CHECK-NEXT: %[[v10:.*]] = select %[[v6]], %[[v8]], %[[v9]] : i32
98-
// CHECK-NEXT: affine.store %[[v10]], %[[ret]][symbol(%[[i]])] : memref<1024xi32>
99-
100-
// CHECK: func private @S4(%[[statemt:.*]]: memref<?xi32>, %[[i:.*]]: index, %[[ret:.*]]: memref<1024xi32>) attributes {scop.stmt}
101-
// CHECK-NEXT: %[[v0:.*]] = affine.load %[[ret]][symbol(%[[i]])] : memref<1024xi32>
102-
// CHECK-NEXT: affine.store %[[v0]], %[[statemt]][symbol(%[[i]])] : memref<?xi32>
103-
104-
// CHECK: func @encrypt(%[[Sbox:.*]]: memref<?x16xi32>, %[[statemt:.*]]: memref<?xi32>)
105-
// CHECK: %[[ret:.*]] = memref.alloca() : memref<1024xi32>
106-
// CHECK: affine.for %[[i:.*]] = 1 to 5
107-
// CHECK: affine.for %[[j:.*]] = 0 to 16
108-
// CHECK: call @S0(%[[statemt]], %[[j]], %[[Sbox]]) : (memref<?xi32>, index, memref<?x16xi32>) -> ()
109-
// CHECK: affine.for %[[j:.*]] = 0 to 1023
110-
// CHECK: call @S1(%[[j]], %[[ret]], %[[statemt]]) : (index, memref<1024xi32>, memref<?xi32>) -> ()
111-
// CHECK: call @S2(%[[j]], %[[ret]]) : (index, memref<1024xi32>) -> ()
112-
// CHECK: call @S3(%[[j]], %[[ret]], %[[statemt]]) : (index, memref<1024xi32>, memref<?xi32>) -> ()
113-
// CHECK: affine.for %[[j:.*]] = 0 to 1024
114-
// CHECK: call @S4(%[[statemt]], %[[j]], %[[ret]]) : (memref<?xi32>, index, memref<1024xi32>) -> ()
54+
// CHECK: func @encrypt(%[[Sbox:.*]]: memref<?x16xi32>, %[[statemt:.*]]: memref<?xi32>)
55+
// CHECK: %[[v0:.*]] = memref.alloca() : memref<1024xi32>
56+
// CHECK: affine.for %[[i:.*]] = 1 to 5
57+
// CHECK: affine.for %[[j:.*]] = 0 to 16
58+
// CHECK: call @S0(%[[statemt]], %[[j]], %[[Sbox]])
59+
// CHECK: affine.for %[[j:.*]] = 0 to 1023
60+
// CHECK: call @S1(%[[j]], %[[v0]], %[[statemt]])
61+
// CHECK: call @S2(%[[j]], %[[v0]], %[[statemt]])
62+
// CHECK: affine.for %[[j:.*]] = 0 to 1024
63+
// CHECK: call @S3(%[[statemt]], %[[j]], %[[v0]])

tools/polymer/tools/polymer-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ target_link_libraries(polymer-opt
2121
MLIRTransformUtils
2222
MLIRSupport
2323
MLIRIR
24+
MLIRAffineTransforms
2425

2526
PolymerTransforms
2627
)

tools/polymer/tools/polymer-opt/polymer-opt.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "polymer/Transforms/ScopStmtOpt.h"
1616

1717
#include "mlir/Dialect/Affine/IR/AffineOps.h"
18+
#include "mlir/Dialect/Affine/Passes.h"
1819
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1920
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
2021
#include "mlir/Dialect/Math/IR/Math.h"
@@ -54,6 +55,7 @@ int main(int argc, char *argv[]) {
5455
registerCanonicalizerPass();
5556
registerCSEPass();
5657
registerInlinerPass();
58+
registerAffineScalarReplacementPass();
5759
// Register polymer specific passes.
5860
registerPlutoTransformPass();
5961
registerRegToMemPass();

0 commit comments

Comments
 (0)