@@ -208,7 +208,6 @@ void ParallelLower::runOnOperation() {
208
208
209
209
SymbolTableCollection symbolTable;
210
210
symbolTable.getSymbolTable (getOperation ());
211
- SymbolUserMap symbolUserMap (symbolTable, getOperation ());
212
211
213
212
getOperation ()->walk ([&](CallOp bidx) {
214
213
if (bidx.getCallee () == " cudaThreadSynchronize" )
@@ -336,52 +335,94 @@ void ParallelLower::runOnOperation() {
336
335
callInliner (op);
337
336
}
338
337
339
- // Only supports single block functions at the moment.
338
+ {
340
339
341
- SmallVector<std::pair<Operation *, size_t >> outlineOps;
342
- getOperation ().walk ([&](gpu::LaunchOp launchOp) {
343
- launchOp.walk ([&](LLVM::CallOp caller) {
344
- if (!caller.getCallee ()) {
345
- outlineOps.push_back (std::make_pair (caller, (size_t )0 ));
346
- }
347
- });
348
- });
349
- SetVector<FunctionOpInterface> toinl;
350
- while (outlineOps.size ()) {
351
- auto opv = outlineOps.back ();
352
- auto op = std::get<0 >(opv);
353
- auto idx = std::get<1 >(opv);
354
- outlineOps.pop_back ();
355
- if (Value fn = op->getOperand (idx)) {
356
- if (auto fn2 = fn.getDefiningOp <polygeist::Memref2PointerOp>())
357
- fn = fn2.getOperand ();
358
- if (auto ba = fn.dyn_cast <BlockArgument>()) {
359
- if (auto F =
360
- dyn_cast<FunctionOpInterface>(ba.getOwner ()->getParentOp ())) {
361
- if (toinl.count (F))
362
- continue ;
363
- toinl.insert (F);
364
- for (Operation *m : symbolUserMap.getUsers (F)) {
365
- outlineOps.push_back (std::make_pair (m, (size_t )ba.getArgNumber ()));
340
+ SmallVector<Operation *> inlineOps;
341
+ SmallVector<mlir::Value> toFollowOps;
342
+ SetVector<FunctionOpInterface> toinl;
343
+
344
+ getOperation ().walk (
345
+ [&](mlir::gpu::ThreadIdOp bidx) { inlineOps.push_back (bidx); });
346
+ getOperation ().walk (
347
+ [&](mlir::gpu::GridDimOp bidx) { inlineOps.push_back (bidx); });
348
+ getOperation ().walk (
349
+ [&](mlir::NVVM::Barrier0Op bidx) { inlineOps.push_back (bidx); });
350
+
351
+ SymbolUserMap symbolUserMap (symbolTable, getOperation ());
352
+ while (inlineOps.size ()) {
353
+ auto op = inlineOps.back ();
354
+ inlineOps.pop_back ();
355
+ auto lop = op->getParentOfType <gpu::LaunchOp>();
356
+ auto fop = op->getParentOfType <FunctionOpInterface>();
357
+ if (!lop || lop->isAncestor (fop)) {
358
+ toinl.insert (fop);
359
+ for (Operation *m : symbolUserMap.getUsers (fop)) {
360
+ if (isa<LLVM::CallOp, func::CallOp>(m))
361
+ inlineOps.push_back (m);
362
+ else if (isa<polygeist::GetFuncOp>(m)) {
363
+ toFollowOps.push_back (m->getResult (0 ));
366
364
}
367
365
}
368
366
}
369
367
}
370
- }
371
- for (auto F : toinl) {
372
- for (Operation *m : symbolUserMap.getUsers (F)) {
373
- callInliner (cast<CallOp>(m));
368
+ for (auto F : toinl) {
369
+ SmallVector<LLVM::CallOp> ltoinl;
370
+ SmallVector<func::CallOp> mtoinl;
371
+ SymbolUserMap symbolUserMap (symbolTable, getOperation ());
372
+ for (Operation *m : symbolUserMap.getUsers (F)) {
373
+ if (auto l = dyn_cast<LLVM::CallOp>(m))
374
+ ltoinl.push_back (l);
375
+ else if (auto mc = dyn_cast<func::CallOp>(m))
376
+ mtoinl.push_back (mc);
377
+ }
378
+ for (auto l : ltoinl) {
379
+ LLVMcallInliner (l);
380
+ }
381
+ for (auto m : mtoinl) {
382
+ callInliner (m);
383
+ }
384
+ }
385
+ while (toFollowOps.size ()) {
386
+ auto op = toFollowOps.back ();
387
+ toFollowOps.pop_back ();
388
+ SmallVector<LLVM::CallOp> ltoinl;
389
+ SmallVector<func::CallOp> mtoinl;
390
+ bool inlined = false ;
391
+ for (auto u : op.getUsers ()) {
392
+ if (auto cop = dyn_cast<LLVM::CallOp>(u)) {
393
+ if (!cop.getCallee () && cop->getOperand (0 ) == op) {
394
+ OpBuilder builder (cop);
395
+ SmallVector<Value> vals;
396
+ if (fixupGetFunc (cop, builder, vals).succeeded ()) {
397
+ if (vals.size ())
398
+ cop.getResult ().replaceAllUsesWith (vals[0 ]);
399
+ cop.erase ();
400
+ inlined = true ;
401
+ break ;
402
+ }
403
+ } else if (cop.getCallee ())
404
+ ltoinl.push_back (cop);
405
+ } else if (auto cop = dyn_cast<func::CallOp>(u)) {
406
+ mtoinl.push_back (cop);
407
+ } else {
408
+ for (auto r : u->getResults ())
409
+ toFollowOps.push_back (r);
410
+ }
411
+ }
412
+ for (auto l : ltoinl) {
413
+ LLVMcallInliner (l);
414
+ inlined = true ;
415
+ }
416
+ for (auto m : mtoinl) {
417
+ callInliner (m);
418
+ inlined = true ;
419
+ }
420
+ if (inlined)
421
+ toFollowOps.push_back (op);
374
422
}
375
423
}
376
- getOperation ().walk ([&](LLVM::CallOp caller) {
377
- OpBuilder builder (caller);
378
- SmallVector<Value> vals;
379
- if (fixupGetFunc (caller, builder, vals).failed ())
380
- return ;
381
- if (vals.size ())
382
- caller.getResult ().replaceAllUsesWith (vals[0 ]);
383
- caller.erase ();
384
- });
424
+
425
+ // Only supports single block functions at the moment.
385
426
386
427
SmallVector<gpu::LaunchOp> toHandle;
387
428
getOperation ().walk (
0 commit comments