Skip to content

Commit f7b222b

Browse files
authored
Fix parallel lower preinline (#313)
1 parent: d8b6e4f · commit: f7b222b

File tree

2 files changed: 111 additions (+111) and 40 deletions (-40).

2 files changed

+111
-40
lines changed

lib/polygeist/Passes/ParallelLower.cpp

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ void ParallelLower::runOnOperation() {
208208

209209
SymbolTableCollection symbolTable;
210210
symbolTable.getSymbolTable(getOperation());
211-
SymbolUserMap symbolUserMap(symbolTable, getOperation());
212211

213212
getOperation()->walk([&](CallOp bidx) {
214213
if (bidx.getCallee() == "cudaThreadSynchronize")
@@ -336,52 +335,94 @@ void ParallelLower::runOnOperation() {
336335
callInliner(op);
337336
}
338337

339-
// Only supports single block functions at the moment.
338+
{
340339

341-
SmallVector<std::pair<Operation *, size_t>> outlineOps;
342-
getOperation().walk([&](gpu::LaunchOp launchOp) {
343-
launchOp.walk([&](LLVM::CallOp caller) {
344-
if (!caller.getCallee()) {
345-
outlineOps.push_back(std::make_pair(caller, (size_t)0));
346-
}
347-
});
348-
});
349-
SetVector<FunctionOpInterface> toinl;
350-
while (outlineOps.size()) {
351-
auto opv = outlineOps.back();
352-
auto op = std::get<0>(opv);
353-
auto idx = std::get<1>(opv);
354-
outlineOps.pop_back();
355-
if (Value fn = op->getOperand(idx)) {
356-
if (auto fn2 = fn.getDefiningOp<polygeist::Memref2PointerOp>())
357-
fn = fn2.getOperand();
358-
if (auto ba = fn.dyn_cast<BlockArgument>()) {
359-
if (auto F =
360-
dyn_cast<FunctionOpInterface>(ba.getOwner()->getParentOp())) {
361-
if (toinl.count(F))
362-
continue;
363-
toinl.insert(F);
364-
for (Operation *m : symbolUserMap.getUsers(F)) {
365-
outlineOps.push_back(std::make_pair(m, (size_t)ba.getArgNumber()));
340+
SmallVector<Operation *> inlineOps;
341+
SmallVector<mlir::Value> toFollowOps;
342+
SetVector<FunctionOpInterface> toinl;
343+
344+
getOperation().walk(
345+
[&](mlir::gpu::ThreadIdOp bidx) { inlineOps.push_back(bidx); });
346+
getOperation().walk(
347+
[&](mlir::gpu::GridDimOp bidx) { inlineOps.push_back(bidx); });
348+
getOperation().walk(
349+
[&](mlir::NVVM::Barrier0Op bidx) { inlineOps.push_back(bidx); });
350+
351+
SymbolUserMap symbolUserMap(symbolTable, getOperation());
352+
while (inlineOps.size()) {
353+
auto op = inlineOps.back();
354+
inlineOps.pop_back();
355+
auto lop = op->getParentOfType<gpu::LaunchOp>();
356+
auto fop = op->getParentOfType<FunctionOpInterface>();
357+
if (!lop || lop->isAncestor(fop)) {
358+
toinl.insert(fop);
359+
for (Operation *m : symbolUserMap.getUsers(fop)) {
360+
if (isa<LLVM::CallOp, func::CallOp>(m))
361+
inlineOps.push_back(m);
362+
else if (isa<polygeist::GetFuncOp>(m)) {
363+
toFollowOps.push_back(m->getResult(0));
366364
}
367365
}
368366
}
369367
}
370-
}
371-
for (auto F : toinl) {
372-
for (Operation *m : symbolUserMap.getUsers(F)) {
373-
callInliner(cast<CallOp>(m));
368+
for (auto F : toinl) {
369+
SmallVector<LLVM::CallOp> ltoinl;
370+
SmallVector<func::CallOp> mtoinl;
371+
SymbolUserMap symbolUserMap(symbolTable, getOperation());
372+
for (Operation *m : symbolUserMap.getUsers(F)) {
373+
if (auto l = dyn_cast<LLVM::CallOp>(m))
374+
ltoinl.push_back(l);
375+
else if (auto mc = dyn_cast<func::CallOp>(m))
376+
mtoinl.push_back(mc);
377+
}
378+
for (auto l : ltoinl) {
379+
LLVMcallInliner(l);
380+
}
381+
for (auto m : mtoinl) {
382+
callInliner(m);
383+
}
384+
}
385+
while (toFollowOps.size()) {
386+
auto op = toFollowOps.back();
387+
toFollowOps.pop_back();
388+
SmallVector<LLVM::CallOp> ltoinl;
389+
SmallVector<func::CallOp> mtoinl;
390+
bool inlined = false;
391+
for (auto u : op.getUsers()) {
392+
if (auto cop = dyn_cast<LLVM::CallOp>(u)) {
393+
if (!cop.getCallee() && cop->getOperand(0) == op) {
394+
OpBuilder builder(cop);
395+
SmallVector<Value> vals;
396+
if (fixupGetFunc(cop, builder, vals).succeeded()) {
397+
if (vals.size())
398+
cop.getResult().replaceAllUsesWith(vals[0]);
399+
cop.erase();
400+
inlined = true;
401+
break;
402+
}
403+
} else if (cop.getCallee())
404+
ltoinl.push_back(cop);
405+
} else if (auto cop = dyn_cast<func::CallOp>(u)) {
406+
mtoinl.push_back(cop);
407+
} else {
408+
for (auto r : u->getResults())
409+
toFollowOps.push_back(r);
410+
}
411+
}
412+
for (auto l : ltoinl) {
413+
LLVMcallInliner(l);
414+
inlined = true;
415+
}
416+
for (auto m : mtoinl) {
417+
callInliner(m);
418+
inlined = true;
419+
}
420+
if (inlined)
421+
toFollowOps.push_back(op);
374422
}
375423
}
376-
getOperation().walk([&](LLVM::CallOp caller) {
377-
OpBuilder builder(caller);
378-
SmallVector<Value> vals;
379-
if (fixupGetFunc(caller, builder, vals).failed())
380-
return;
381-
if (vals.size())
382-
caller.getResult().replaceAllUsesWith(vals[0]);
383-
caller.erase();
384-
});
424+
425+
// Only supports single block functions at the moment.
385426

386427
SmallVector<gpu::LaunchOp> toHandle;
387428
getOperation().walk(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: cgeist %s --cuda-gpu-arch=sm_60 -nocudalib -nocudainc %resourcedir --function=* --cuda-lower --cpuify="distribute" -S | FileCheck %s
2+
3+
#include "Inputs/cuda.h"
4+
#include "__clang_cuda_builtin_vars.h"
5+
6+
#define N 20
7+
8+
__device__ void bar(double* w) {
9+
w[threadIdx.x] = 2.0;
10+
}
11+
12+
__global__ void foo(double * w) {
13+
bar(w);
14+
}
15+
16+
void something(double*);
17+
18+
template<typename T>
19+
void templ(T fn, double *w) {
20+
something(w);
21+
}
22+
23+
void start(double* w) {
24+
templ(foo, w);
25+
}
26+
27+
// CHECK: func.func @_Z5startPd(%arg0: memref<?xf64>)
28+
// CHECK-NEXT: call @_Z9somethingPd(%arg0) : (memref<?xf64>) -> ()
29+
// CHECK-NEXT: return
30+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)