27
27
#include " mlir/Transforms/Passes.h"
28
28
#include " polygeist/Ops.h"
29
29
#include " polygeist/Passes/Passes.h"
30
+ #include " llvm/ADT/SetVector.h"
30
31
#include " llvm/ADT/SmallPtrSet.h"
31
32
#include < algorithm>
32
33
#include < mutex>
@@ -198,18 +199,23 @@ mlir::LLVM::LLVMFuncOp GetOrCreateFreeFunction(ModuleOp module) {
198
199
lnk);
199
200
}
200
201
202
// Forward declaration of fixupGetFunc; its definition is not in this part of
// the file.
// ParallelLower::runOnOperation() below calls it for each LLVM::CallOp found
// by walking the pass's root operation, passing an OpBuilder constructed from
// that call and an initially empty value vector. When it succeeds, the call
// site replaces all uses of the call's result with the first element of the
// vector (only if the vector is non-empty) and then erases the call; when it
// fails, the call is left in place.
// NOTE(review): presumably this rewrites calls that fetch a function handle
// -- confirm against the definition.
+ LogicalResult fixupGetFunc (LLVM::CallOp, OpBuilder &rewriter,
203
+ SmallVectorImpl<Value> &);
204
+
201
205
void ParallelLower::runOnOperation () {
202
206
// The inliner should only be run on operations that define a symbol table,
203
207
// as the callgraph will need to resolve references.
204
208
205
209
SymbolTableCollection symbolTable;
206
210
symbolTable.getSymbolTable (getOperation ());
211
+ SymbolUserMap symbolUserMap (symbolTable, getOperation ());
207
212
208
213
getOperation ()->walk ([&](CallOp bidx) {
209
214
if (bidx.getCallee () == " cudaThreadSynchronize" )
210
215
bidx.erase ();
211
216
});
212
217
218
+ std::function<void (LLVM::CallOp)> LLVMcallInliner;
213
219
std::function<void (CallOp)> callInliner = [&](CallOp caller) {
214
220
// Build the inliner interface.
215
221
AlwaysInlinerInterface interface (&getContext ());
@@ -230,10 +236,72 @@ void ParallelLower::runOnOperation() {
230
236
return ;
231
237
if (targetRegion->empty ())
232
238
return ;
233
- SmallVector<CallOp> ops;
234
- callableOp.walk ([&](CallOp caller) { ops.push_back (caller); });
235
- for (auto op : ops)
236
- callInliner (op);
239
+ {
240
+ SmallVector<CallOp> ops;
241
+ callableOp.walk ([&](CallOp caller) { ops.push_back (caller); });
242
+ for (auto op : ops)
243
+ callInliner (op);
244
+ }
245
+ {
246
+ SmallVector<LLVM::CallOp> ops;
247
+ callableOp.walk ([&](LLVM::CallOp caller) { ops.push_back (caller); });
248
+ for (auto op : ops)
249
+ LLVMcallInliner (op);
250
+ }
251
+ OpBuilder b (caller);
252
+ auto allocScope = b.create <memref::AllocaScopeOp>(caller.getLoc (),
253
+ caller.getResultTypes ());
254
+ allocScope.getRegion ().push_back (new Block ());
255
+ b.setInsertionPointToStart (&allocScope.getRegion ().front ());
256
+ auto exOp = b.create <scf::ExecuteRegionOp>(caller.getLoc (),
257
+ caller.getResultTypes ());
258
+ Block *blk = new Block ();
259
+ exOp.getRegion ().push_back (blk);
260
+ caller->moveBefore (blk, blk->begin ());
261
+ caller.replaceAllUsesWith (allocScope.getResults ());
262
+ b.setInsertionPointToEnd (blk);
263
+ b.create <scf::YieldOp>(caller.getLoc (), caller.getResults ());
264
+ if (inlineCall (interface, caller, callableOp, targetRegion,
265
+ /* shouldCloneInlinedRegion=*/ true )
266
+ .succeeded ()) {
267
+ caller.erase ();
268
+ }
269
+ b.setInsertionPointToEnd (&allocScope.getRegion ().front ());
270
+ b.create <memref::AllocaScopeReturnOp>(allocScope.getLoc (),
271
+ exOp.getResults ());
272
+ };
273
+ LLVMcallInliner = [&](LLVM::CallOp caller) {
274
+ // Build the inliner interface.
275
+ AlwaysInlinerInterface interface (&getContext ());
276
+
277
+ auto callable = caller.getCallableForCallee ();
278
+ CallableOpInterface callableOp;
279
+ if (SymbolRefAttr symRef = callable.dyn_cast <SymbolRefAttr>()) {
280
+ if (!symRef.isa <FlatSymbolRefAttr>())
281
+ return ;
282
+ auto *symbolOp =
283
+ symbolTable.lookupNearestSymbolFrom (getOperation (), symRef);
284
+ callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
285
+ } else {
286
+ return ;
287
+ }
288
+ Region *targetRegion = callableOp.getCallableRegion ();
289
+ if (!targetRegion)
290
+ return ;
291
+ if (targetRegion->empty ())
292
+ return ;
293
+ {
294
+ SmallVector<CallOp> ops;
295
+ callableOp.walk ([&](CallOp caller) { ops.push_back (caller); });
296
+ for (auto op : ops)
297
+ callInliner (op);
298
+ }
299
+ {
300
+ SmallVector<LLVM::CallOp> ops;
301
+ callableOp.walk ([&](LLVM::CallOp caller) { ops.push_back (caller); });
302
+ for (auto op : ops)
303
+ LLVMcallInliner (op);
304
+ }
237
305
OpBuilder b (caller);
238
306
auto allocScope = b.create <memref::AllocaScopeOp>(caller.getLoc (),
239
307
caller.getResultTypes ());
@@ -256,6 +324,7 @@ void ParallelLower::runOnOperation() {
256
324
b.create <memref::AllocaScopeReturnOp>(allocScope.getLoc (),
257
325
exOp.getResults ());
258
326
};
327
+
259
328
{
260
329
SmallVector<CallOp> dimsToInline;
261
330
getOperation ()->walk ([&](CallOp bidx) {
@@ -268,15 +337,68 @@ void ParallelLower::runOnOperation() {
268
337
}
269
338
270
339
// Only supports single block functions at the moment.
340
+
341
+ SmallVector<std::pair<Operation *, size_t >> outlineOps;
342
+ getOperation ().walk ([&](gpu::LaunchOp launchOp) {
343
+ launchOp.walk ([&](LLVM::CallOp caller) {
344
+ if (!caller.getCallee ()) {
345
+ outlineOps.push_back (std::make_pair (caller, (size_t )0 ));
346
+ }
347
+ });
348
+ });
349
+ SetVector<FunctionOpInterface> toinl;
350
+ while (outlineOps.size ()) {
351
+ auto opv = outlineOps.back ();
352
+ auto op = std::get<0 >(opv);
353
+ auto idx = std::get<1 >(opv);
354
+ outlineOps.pop_back ();
355
+ if (Value fn = op->getOperand (idx)) {
356
+ if (auto fn2 = fn.getDefiningOp <polygeist::Memref2PointerOp>())
357
+ fn = fn2.getOperand ();
358
+ if (auto ba = fn.dyn_cast <BlockArgument>()) {
359
+ if (auto F =
360
+ dyn_cast<FunctionOpInterface>(ba.getOwner ()->getParentOp ())) {
361
+ if (toinl.count (F))
362
+ continue ;
363
+ toinl.insert (F);
364
+ for (Operation *m : symbolUserMap.getUsers (F)) {
365
+ outlineOps.push_back (std::make_pair (m, (size_t )ba.getArgNumber ()));
366
+ }
367
+ }
368
+ }
369
+ }
370
+ }
371
+ for (auto F : toinl) {
372
+ for (Operation *m : symbolUserMap.getUsers (F)) {
373
+ callInliner (cast<CallOp>(m));
374
+ }
375
+ }
376
+ getOperation ().walk ([&](LLVM::CallOp caller) {
377
+ OpBuilder builder (caller);
378
+ SmallVector<Value> vals;
379
+ if (fixupGetFunc (caller, builder, vals).failed ())
380
+ return ;
381
+ if (vals.size ())
382
+ caller.getResult ().replaceAllUsesWith (vals[0 ]);
383
+ caller.erase ();
384
+ });
385
+
271
386
SmallVector<gpu::LaunchOp> toHandle;
272
387
getOperation ().walk (
273
388
[&](gpu::LaunchOp launchOp) { toHandle.push_back (launchOp); });
274
-
275
389
for (gpu::LaunchOp launchOp : toHandle) {
276
- SmallVector<CallOp> ops;
277
- launchOp.walk ([&](CallOp caller) { ops.push_back (caller); });
278
- for (auto op : ops)
279
- callInliner (op);
390
+ {
391
+ SmallVector<CallOp> ops;
392
+ launchOp.walk ([&](CallOp caller) { ops.push_back (caller); });
393
+ for (auto op : ops)
394
+ callInliner (op);
395
+ }
396
+ {
397
+ SmallVector<LLVM::CallOp> lops;
398
+ launchOp.walk ([&](LLVM::CallOp caller) { lops.push_back (caller); });
399
+ for (auto op : lops)
400
+ LLVMcallInliner (op);
401
+ }
280
402
281
403
mlir::IRRewriter builder (launchOp.getContext ());
282
404
auto loc = launchOp.getLoc ();
0 commit comments