11
11
// ===----------------------------------------------------------------------===//
12
12
#include " PassDetails.h"
13
13
14
+ #include " mlir/../../lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp"
14
15
#include " mlir/Analysis/DataLayoutAnalysis.h"
15
16
#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
16
17
#include " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -350,12 +351,13 @@ struct LLVMOpLowering : public ConversionPattern {
350
351
state.addRegion ();
351
352
352
353
Operation *rewritten = rewriter.create (state);
353
- rewriter.replaceOp (op, rewritten->getResults ());
354
354
355
355
for (unsigned i = 0 , e = op->getNumRegions (); i < e; ++i)
356
356
rewriter.inlineRegionBefore (op->getRegion (i), rewritten->getRegion (i),
357
357
rewritten->getRegion (i).begin ());
358
358
359
+ rewriter.replaceOp (op, rewritten->getResults ());
360
+
359
361
return success ();
360
362
}
361
363
};
@@ -407,6 +409,35 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) {
407
409
return resumeOp;
408
410
}
409
411
412
+ // / In some cases such as scf.for, the blocks generated when it gets lowered
413
+ // / depend on the parent region having already been lowered and having a
414
+ // / converter assigned to it - this pattern assures that execute ops have a
415
+ // / converter becaus they will actually be lowered only after everything else
416
+ // / has been converted to llvm
417
+ class ConvertExecuteOpTypes : public ConvertOpToLLVMPattern <async::ExecuteOp> {
418
+ public:
419
+ using ConvertOpToLLVMPattern<async::ExecuteOp>::ConvertOpToLLVMPattern;
420
+ LogicalResult
421
+ matchAndRewrite (async::ExecuteOp op, OpAdaptor adaptor,
422
+ ConversionPatternRewriter &rewriter) const override {
423
+ async::ExecuteOp newOp = cast<async::ExecuteOp>(
424
+ rewriter.cloneWithoutRegions (*op.getOperation ()));
425
+ rewriter.inlineRegionBefore (op.getRegion (), newOp.getRegion (),
426
+ newOp.getRegion ().end ());
427
+
428
+ // Set operands and update block argument and result types.
429
+ newOp->setOperands (adaptor.getOperands ());
430
+ if (failed (rewriter.convertRegionTypes (&newOp.getRegion (), *typeConverter)))
431
+ return failure ();
432
+ for (auto result : newOp.getResults ())
433
+ result.setType (typeConverter->convertType (result.getType ()));
434
+
435
+ newOp->setAttr (" polygeist.handled" , rewriter.getUnitAttr ());
436
+ rewriter.replaceOp (op, newOp.getResults ());
437
+ return success ();
438
+ }
439
+ };
440
+
410
441
struct AsyncOpLowering : public ConvertOpToLLVMPattern <async::ExecuteOp> {
411
442
using ConvertOpToLLVMPattern<async::ExecuteOp>::ConvertOpToLLVMPattern;
412
443
@@ -423,12 +454,12 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
423
454
424
455
// Make sure that all constants will be inside the outlined async function
425
456
// to reduce the number of function arguments.
426
- Region &funcReg = execute.getBodyRegion ();
457
+ Region &execReg = execute.getBodyRegion ();
427
458
428
459
// Collect all outlined function inputs.
429
460
SetVector<mlir::Value> functionInputs;
430
461
431
- getUsedValuesDefinedAbove (execute.getBodyRegion (), funcReg , functionInputs);
462
+ getUsedValuesDefinedAbove (execute.getBodyRegion (), execReg , functionInputs);
432
463
SmallVector<Value> toErase;
433
464
for (auto a : functionInputs) {
434
465
Operation *op = a.getDefiningOp ();
@@ -451,16 +482,18 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
451
482
452
483
// TODO: Derive outlined function name from the parent FuncOp (support
453
484
// multiple nested async.execute operations).
454
- auto moduleBuilder =
455
- ImplicitLocOpBuilder::atBlockEnd (loc, module .getBody ());
456
-
457
- static int off = 0 ;
458
- off++;
459
- auto func = moduleBuilder.create <LLVM::LLVMFuncOp>(
460
- execute.getLoc (),
461
- " kernelbody." + std::to_string ((long long int )&execute) + " ." +
462
- std::to_string (off),
463
- funcType);
485
+ LLVM::LLVMFuncOp func;
486
+ {
487
+ OpBuilder::InsertionGuard guard (rewriter);
488
+ rewriter.setInsertionPointToEnd (module .getBody ());
489
+ static int off = 0 ;
490
+ off++;
491
+ func = rewriter.create <LLVM::LLVMFuncOp>(
492
+ execute.getLoc (),
493
+ " kernelbody." + std::to_string ((long long int )&execute) + " ." +
494
+ std::to_string (off),
495
+ funcType);
496
+ }
464
497
465
498
rewriter.setInsertionPointToStart (func.addEntryBlock ());
466
499
BlockAndValueMapping valueMapping;
@@ -522,10 +555,17 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
522
555
523
556
// Clone all operations from the execute operation body into the outlined
524
557
// function body.
525
- for (Operation &op : execute.getBody ()->without_terminator ())
526
- rewriter.clone (op, valueMapping);
527
-
528
- rewriter.create <LLVM::ReturnOp>(execute.getLoc (), ValueRange ());
558
+ rewriter.cloneRegionBefore (execute.getBodyRegion (), func.getRegion (),
559
+ func.getRegion ().end (), valueMapping);
560
+ rewriter.create <LLVM::BrOp>(execute.getLoc (), ValueRange (),
561
+ &*std::next (func.getRegion ().begin ()));
562
+ for (Block &b : func.getRegion ()) {
563
+ auto term = b.getTerminator ();
564
+ if (isa<async::YieldOp>(term)) {
565
+ rewriter.setInsertionPointToEnd (&b);
566
+ rewriter.replaceOpWithNewOp <LLVM::ReturnOp>(term, ValueRange ());
567
+ }
568
+ }
529
569
}
530
570
531
571
// Replace the original `async.execute` with a call to outlined function.
@@ -703,7 +743,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<OpTy> {
703
743
};
704
744
705
745
// / Pattern for lowering automatic stack allocations.
706
- struct AllocaOpLowering : public AllocLikeOpLowering <memref::AllocaOp> {
746
+ struct CAllocaOpLowering : public AllocLikeOpLowering <memref::AllocaOp> {
707
747
public:
708
748
using AllocLikeOpLowering<memref::AllocaOp>::AllocLikeOpLowering;
709
749
@@ -729,7 +769,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering<memref::AllocaOp> {
729
769
};
730
770
731
771
// / Pattern for lowering heap allocations via malloc.
732
- struct AllocOpLowering : public AllocLikeOpLowering <memref::AllocOp> {
772
+ struct CAllocOpLowering : public AllocLikeOpLowering <memref::AllocOp> {
733
773
public:
734
774
using AllocLikeOpLowering<memref::AllocOp>::AllocLikeOpLowering;
735
775
@@ -783,7 +823,7 @@ struct AllocOpLowering : public AllocLikeOpLowering<memref::AllocOp> {
783
823
};
784
824
785
825
// / Pattern for lowering heap deallocations via free.
786
- struct DeallocOpLowering : public ConvertOpToLLVMPattern <memref::DeallocOp> {
826
+ struct CDeallocOpLowering : public ConvertOpToLLVMPattern <memref::DeallocOp> {
787
827
public:
788
828
using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
789
829
@@ -914,7 +954,7 @@ struct GetGlobalOpLowering
914
954
915
955
// / Base class for patterns lowering memory access operations.
916
956
template <typename OpTy>
917
- struct LoadStoreOpLowering : public ConvertOpToLLVMPattern <OpTy> {
957
+ struct CLoadStoreOpLowering : public ConvertOpToLLVMPattern <OpTy> {
918
958
protected:
919
959
using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
920
960
@@ -941,9 +981,9 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<OpTy> {
941
981
};
942
982
943
983
// / Pattern for lowering a memory load.
944
- struct LoadOpLowering : public LoadStoreOpLowering <memref::LoadOp> {
984
+ struct CLoadOpLowering : public CLoadStoreOpLowering <memref::LoadOp> {
945
985
public:
946
- using LoadStoreOpLowering <memref::LoadOp>::LoadStoreOpLowering ;
986
+ using CLoadStoreOpLowering <memref::LoadOp>::CLoadStoreOpLowering ;
947
987
948
988
LogicalResult
949
989
matchAndRewrite (memref::LoadOp loadOp, OpAdaptor adaptor,
@@ -958,9 +998,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
958
998
};
959
999
960
1000
// / Pattern for lowering a memory store.
961
- struct StoreOpLowering : public LoadStoreOpLowering <memref::StoreOp> {
1001
+ struct CStoreOpLowering : public CLoadStoreOpLowering <memref::StoreOp> {
962
1002
public:
963
- using LoadStoreOpLowering <memref::StoreOp>::LoadStoreOpLowering ;
1003
+ using CLoadStoreOpLowering <memref::StoreOp>::CLoadStoreOpLowering ;
964
1004
965
1005
LogicalResult
966
1006
matchAndRewrite (memref::StoreOp storeOp, OpAdaptor adaptor,
@@ -1242,9 +1282,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
1242
1282
static void
1243
1283
populateCStyleMemRefLoweringPatterns (RewritePatternSet &patterns,
1244
1284
LLVMTypeConverter &typeConverter) {
1245
- patterns.add <AllocaOpLowering, AllocOpLowering, DeallocOpLowering ,
1246
- GetGlobalOpLowering, GlobalOpLowering, LoadOpLowering ,
1247
- StoreOpLowering >(typeConverter);
1285
+ patterns.add <CAllocaOpLowering, CAllocOpLowering, CDeallocOpLowering ,
1286
+ GetGlobalOpLowering, GlobalOpLowering, CLoadOpLowering ,
1287
+ CStoreOpLowering, AllocaScopeOpLowering >(typeConverter);
1248
1288
}
1249
1289
1250
1290
// / Appends the patterns lowering operations from the Func dialect to the LLVM
@@ -1292,40 +1332,42 @@ struct ConvertPolygeistToLLVMPass
1292
1332
1293
1333
options.dataLayout = llvm::DataLayout (this ->dataLayout );
1294
1334
1295
- for (int i = 0 ; i < 2 ; i++) {
1335
+ // Define the type converter. Override the default behavior for memrefs if
1336
+ // requested.
1337
+ LLVMTypeConverter converter (&getContext (), options, &dataLayoutAnalysis);
1338
+ if (useCStyleMemRef) {
1339
+ converter.addConversion ([&](MemRefType type) -> Optional<Type> {
1340
+ Type converted = converter.convertType (type.getElementType ());
1341
+ if (!converted)
1342
+ return Type ();
1296
1343
1297
- // Define the type converter. Override the default behavior for memrefs if
1298
- // requested.
1299
- LLVMTypeConverter converter (&getContext (), options, &dataLayoutAnalysis);
1300
- if (useCStyleMemRef) {
1301
- converter.addConversion ([&](MemRefType type) -> Optional<Type> {
1302
- Type converted = converter.convertType (type.getElementType ());
1303
- if (!converted)
1304
- return Type ();
1305
-
1306
- if (type.getRank () == 0 ) {
1307
- return LLVM::LLVMPointerType::get (converted,
1308
- type.getMemorySpaceAsInt ());
1309
- }
1310
-
1311
- // Only the leading dimension can be dynamic.
1312
- if (llvm::any_of (type.getShape ().drop_front (), ShapedType::isDynamic))
1313
- return Type ();
1314
-
1315
- // Only identity layout is supported.
1316
- // TODO: detect the strided layout that is equivalent to identity
1317
- // given the static part of the shape.
1318
- if (!type.getLayout ().isIdentity ())
1319
- return Type ();
1320
-
1321
- if (type.getRank () > 0 ) {
1322
- for (int64_t size : llvm::reverse (type.getShape ().drop_front ()))
1323
- converted = LLVM::LLVMArrayType::get (converted, size);
1324
- }
1344
+ if (type.getRank () == 0 ) {
1325
1345
return LLVM::LLVMPointerType::get (converted,
1326
1346
type.getMemorySpaceAsInt ());
1327
- });
1328
- }
1347
+ }
1348
+
1349
+ // Only the leading dimension can be dynamic.
1350
+ if (llvm::any_of (type.getShape ().drop_front (), ShapedType::isDynamic))
1351
+ return Type ();
1352
+
1353
+ // Only identity layout is supported.
1354
+ // TODO: detect the strided layout that is equivalent to identity
1355
+ // given the static part of the shape.
1356
+ if (!type.getLayout ().isIdentity ())
1357
+ return Type ();
1358
+
1359
+ if (type.getRank () > 0 ) {
1360
+ for (int64_t size : llvm::reverse (type.getShape ().drop_front ()))
1361
+ converted = LLVM::LLVMArrayType::get (converted, size);
1362
+ }
1363
+ return LLVM::LLVMPointerType::get (converted,
1364
+ type.getMemorySpaceAsInt ());
1365
+ });
1366
+ }
1367
+
1368
+ converter.addConversion ([&](async::TokenType type) { return type; });
1369
+
1370
+ for (int i = 0 ; i < 2 ; i++) {
1329
1371
1330
1372
RewritePatternSet patterns (&getContext ());
1331
1373
populatePolygeistToLLVMConversionPatterns (converter, patterns);
@@ -1343,8 +1385,6 @@ struct ConvertPolygeistToLLVMPass
1343
1385
populateOpenMPToLLVMConversionPatterns (converter, patterns);
1344
1386
arith::populateArithToLLVMConversionPatterns (converter, patterns);
1345
1387
1346
- converter.addConversion ([&](async::TokenType type) { return type; });
1347
-
1348
1388
patterns.add <LLVMOpLowering, GlobalOpTypeConversion,
1349
1389
ReturnOpTypeConversion, GetFuncOpConversion>(converter);
1350
1390
patterns.add <URLLVMOpLowering>(converter);
@@ -1399,10 +1439,16 @@ struct ConvertPolygeistToLLVMPass
1399
1439
op->getResult(0).getType(); });
1400
1440
*/
1401
1441
1402
- if (i == 1 ) {
1442
+ if (i == 0 ) {
1443
+ patterns.add <ConvertExecuteOpTypes>(converter);
1444
+ target.addDynamicallyLegalOp <async::ExecuteOp>(
1445
+ [&](async::ExecuteOp eo) {
1446
+ return eo->hasAttr (" polygeist.handled" );
1447
+ });
1448
+ } else if (i == 1 ) {
1403
1449
// target.addIllegalOp<UnrealizedConversionCastOp>();
1404
- patterns.add <AsyncOpLowering>(converter);
1405
1450
patterns.add <StreamToTokenOpLowering>(converter);
1451
+ patterns.add <AsyncOpLowering>(converter);
1406
1452
}
1407
1453
if (failed (applyPartialConversion (m, target, std::move (patterns))))
1408
1454
signalPassFailure ();
0 commit comments