Skip to content

Commit 93d5a1e

Browse files
authored
Fix Parallel LICM & inheritance (#196)
* Fix Parallel LICM condition
* Fix ParallelLICM
* Fix inheritance base lowering
* Fix address merged offset of base
* Add tests
* Handle new virtual/pure only
* Stabilize ptraddsubtest
* Fix flag name
* Override cudaGetLastError
* Fix format
1 parent e7489c4 commit 93d5a1e

File tree

10 files changed

+391
-178
lines changed

10 files changed

+391
-178
lines changed

lib/polygeist/Passes/ParallelLICM.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,25 @@ static bool canBeParallelHoisted(Operation *op, Operation *scope,
6565
freeResources.push_back(effect.getResource());
6666
}
6767

68-
auto conflicting = [&](Operation *b) {
68+
std::function<bool(Operation *)> conflicting = [&](Operation *b) {
6969
if (willBeMoved.count(b))
7070
return false;
71+
72+
if (b->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
73+
74+
for (auto &region : b->getRegions()) {
75+
for (auto &block : region) {
76+
for (auto &innerOp : block)
77+
if (conflicting(&innerOp))
78+
return true;
79+
}
80+
}
81+
return false;
82+
}
83+
84+
auto memEffect = dyn_cast<MemoryEffectOpInterface>(b);
85+
if (!memEffect)
86+
return true;
7187
for (auto res : readResources) {
7288
SmallVector<MemoryEffects::EffectInstance> effects;
7389
memEffect.getEffectsOnResource(res, effects);
@@ -107,14 +123,14 @@ static bool canBeParallelHoisted(Operation *op, Operation *scope,
107123
for (Operation *it = b->getPrevNode(); it != nullptr;
108124
it = it->getPrevNode()) {
109125
if (conflicting(it)) {
110-
return false;
126+
return true;
111127
}
112128
}
113129

114130
if (b->getParentOp() == scope)
115131
return false;
116132
if (hasConflictBefore(b->getParentOp()))
117-
return false;
133+
return true;
118134

119135
bool conflict = false;
120136
// If the parent operation is not guaranteed to execute its (single-block)
@@ -192,7 +208,7 @@ LogicalResult moveParallelLoopInvariantCode(scf::ParallelOp looplike) {
192208
Value cond = nullptr;
193209
for (auto pair :
194210
llvm::zip(looplike.getLowerBound(), looplike.getUpperBound())) {
195-
auto val = b.create<arith::CmpIOp>(looplike.getLoc(), CmpIPredicate::sgt,
211+
auto val = b.create<arith::CmpIOp>(looplike.getLoc(), CmpIPredicate::slt,
196212
std::get<0>(pair), std::get<1>(pair));
197213
if (cond == nullptr)
198214
cond = val;

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,14 @@ void ParallelLower::runOnOperation() {
484484
Value vals[] = {retv};
485485
call.replaceAllUsesWith(ArrayRef<Value>(vals));
486486
call.erase();
487+
} else if (call.getCallee().getValue() == "cudaGetLastError") {
488+
OpBuilder bz(call);
489+
auto retv = bz.create<ConstantIntOp>(
490+
call.getLoc(), 0,
491+
call.getResult(0).getType().cast<IntegerType>().getWidth());
492+
Value vals[] = {retv};
493+
call.replaceAllUsesWith(ArrayRef<Value>(vals));
494+
call.erase();
487495
}
488496
});
489497
getOperation().walk([&](CallOp call) {
@@ -508,6 +516,14 @@ void ParallelLower::runOnOperation() {
508516
call.replaceAllUsesWith(
509517
bz.create<ConstantIntOp>(call.getLoc(), 0, call.getType(0)));
510518
call.erase();
519+
} else if (call.getCallee() == "cudaGetLastError") {
520+
OpBuilder bz(call);
521+
auto retv = bz.create<ConstantIntOp>(
522+
call.getLoc(), 0,
523+
call.getResult(0).getType().cast<IntegerType>().getWidth());
524+
Value vals[] = {retv};
525+
call.replaceAllUsesWith(ArrayRef<Value>(vals));
526+
call.erase();
511527
}
512528
});
513529

test/polygeist-opt/parallellicm.mlir

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: polygeist-opt --parallel-licm --split-input-file %s | FileCheck %s
2+
3+
module {
4+
func private @use(f32)
5+
func @hoist(%arg0: memref<?xf32>, %arg1: index, %arg2: index) {
6+
%cst = arith.constant 0.000000e+00 : f32
7+
%c1 = arith.constant 1 : index
8+
%a = memref.alloca() : memref<f32>
9+
memref.store %cst, %a[] : memref<f32>
10+
scf.parallel (%arg3) = (%arg1) to (%arg2) step (%c1) {
11+
%v = memref.load %a[] : memref<f32>
12+
call @use(%v) : (f32) -> ()
13+
}
14+
return
15+
}
16+
func @hoist2(%arg0: memref<?xf32>, %arg1: index, %arg2: index) {
17+
%cst = arith.constant 0.000000e+00 : f32
18+
%c1 = arith.constant 1 : index
19+
%a = memref.alloca() : memref<f32>
20+
scf.parallel (%arg3) = (%arg1) to (%arg2) step (%c1) {
21+
memref.store %cst, %a[] : memref<f32>
22+
%v = memref.load %a[] : memref<f32>
23+
call @use(%v) : (f32) -> ()
24+
}
25+
return
26+
}
27+
func private @get() -> (f32)
28+
func @nohoist(%arg0: memref<?xf32>, %arg1: index, %arg2: index) {
29+
%c1 = arith.constant 1 : index
30+
%a = memref.alloca() : memref<f32>
31+
scf.parallel (%arg3) = (%arg1) to (%arg2) step (%c1) {
32+
%cst = call @get() : () -> (f32)
33+
memref.store %cst, %a[] : memref<f32>
34+
%v = memref.load %a[] : memref<f32>
35+
call @use(%v) : (f32) -> ()
36+
}
37+
return
38+
}
39+
}
40+
41+
// CHECK: func @hoist(%arg0: memref<?xf32>, %arg1: index, %arg2: index) {
42+
// CHECK-DAG: %cst = arith.constant 0.000000e+00 : f32
43+
// CHECK-DAG: %c1 = arith.constant 1 : index
44+
// CHECK-NEXT: %0 = memref.alloca() : memref<f32>
45+
// CHECK-NEXT: memref.store %cst, %0[] : memref<f32>
46+
// CHECK-NEXT: %1 = arith.cmpi slt, %arg1, %arg2 : index
47+
// CHECK-NEXT: scf.if %1 {
48+
// CHECK-NEXT: %2 = memref.load %0[] : memref<f32>
49+
// CHECK-NEXT: scf.parallel (%arg3) = (%arg1) to (%arg2) step (%c1) {
50+
// CHECK-NEXT: call @use(%2) : (f32) -> ()
51+
// CHECK-NEXT: scf.yield
52+
// CHECK-NEXT: }
53+
// CHECK-NEXT: }
54+
// CHECK-NEXT: return
55+
// CHECK-NEXT: }
56+
57+
// CHECK: func @hoist2(%arg0: memref<?xf32>, %arg1: index, %arg2: index) {
58+
// CHECK-DAG: %cst = arith.constant 0.000000e+00 : f32
59+
// CHECK-DAG: %c1 = arith.constant 1 : index
60+
// CHECK-NEXT: %0 = memref.alloca() : memref<f32>
61+
// CHECK-NEXT: %1 = arith.cmpi slt, %arg1, %arg2 : index
62+
// CHECK-NEXT: scf.if %1 {
63+
// CHECK-NEXT: memref.store %cst, %0[] : memref<f32>
64+
// CHECK-NEXT: %2 = memref.load %0[] : memref<f32>
65+
// CHECK-NEXT: scf.parallel (%arg3) = (%arg1) to (%arg2) step (%c1) {
66+
// CHECK-NEXT: call @use(%2) : (f32) -> ()
67+
// CHECK-NEXT: scf.yield
68+
// CHECK-NEXT: }
69+
// CHECK-NEXT: }
70+
// CHECK-NEXT: return
71+
// CHECK-NEXT: }
72+
73+
// CHECK: func @nohoist(%arg0: memref<?xf32>, %arg1: index, %arg2: index) {
74+
// CHECK-NEXT: %c1 = arith.constant 1 : index
75+
// CHECK-NEXT: %0 = memref.alloca() : memref<f32>
76+
// CHECK-NEXT: scf.parallel (%arg3) = (%arg1) to (%arg2) step (%c1) {
77+
// CHECK-NEXT: %1 = call @get() : () -> f32
78+
// CHECK-NEXT: memref.store %1, %0[] : memref<f32>
79+
// CHECK-NEXT: %2 = memref.load %0[] : memref<f32>
80+
// CHECK-NEXT: call @use(%2) : (f32) -> ()
81+
// CHECK-NEXT: scf.yield
82+
// CHECK-NEXT: }
83+
// CHECK-NEXT: return
84+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)