Skip to content

Commit bcd804b

Browse files
authored
Fix async lowering for exec body that lowers to multiple blocks (#302)
* Fix async lowering for exec body that lowers to multiple blocks * format * Remove prints * test * comment * Fix not using rewriter to insert op * Handle unlowered memref alloca scopes
1 parent 65bc29b commit bcd804b

File tree

3 files changed

+215
-121
lines changed

3 files changed

+215
-121
lines changed

lib/polygeist/Ops.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2435,8 +2435,9 @@ struct AggressiveAllocaScopeInliner
24352435

24362436
LogicalResult matchAndRewrite(memref::AllocaScopeOp op,
24372437
PatternRewriter &rewriter) const override {
2438-
bool hasPotentialAlloca =
2439-
op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
2438+
auto hasPotentialAlloca = [&]() {
2439+
return op
2440+
->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
24402441
if (alloc == op || isa<LLVM::CallOp>(alloc) ||
24412442
isa<func::CallOp>(alloc) || isa<omp::BarrierOp>(alloc) ||
24422443
isa<polygeist::BarrierOp>(alloc))
@@ -2446,17 +2447,20 @@ struct AggressiveAllocaScopeInliner
24462447
if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
24472448
return WalkResult::skip();
24482449
return WalkResult::advance();
2449-
}).wasInterrupted();
2450-
2451-
// If this contains no potential allocation, it is always legal to
2452-
// inline. Otherwise, consider two conditions:
2453-
if (hasPotentialAlloca) {
2454-
// If the parent isn't an allocation scope, or we are not the last
2455-
// non-terminator op in the parent, we will extend the lifetime.
2456-
if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
2457-
return failure();
2458-
// if (!lastNonTerminatorInRegion(op))
2459-
// return failure();
2450+
})
2451+
.wasInterrupted();
2452+
};
2453+
2454+
// If we are not the last non-terminator operation in the block
2455+
if (op->getNextNode() != op->getBlock()->getTerminator()) {
2456+
// If this contains no potential allocation, it is always legal to
2457+
// inline. Otherwise, consider two conditions:
2458+
if (hasPotentialAlloca()) {
2459+
// If the parent isn't an allocation scope, or we are not the last
2460+
// non-terminator op in the parent, we will extend the lifetime.
2461+
if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
2462+
return failure();
2463+
}
24602464
}
24612465

24622466
Block *block = &op.getRegion().front();

lib/polygeist/Passes/ConvertPolygeistToLLVM.cpp

Lines changed: 109 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212
#include "PassDetails.h"
1313

14+
#include "mlir/../../lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp"
1415
#include "mlir/Analysis/DataLayoutAnalysis.h"
1516
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1617
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
@@ -350,12 +351,13 @@ struct LLVMOpLowering : public ConversionPattern {
350351
state.addRegion();
351352

352353
Operation *rewritten = rewriter.create(state);
353-
rewriter.replaceOp(op, rewritten->getResults());
354354

355355
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
356356
rewriter.inlineRegionBefore(op->getRegion(i), rewritten->getRegion(i),
357357
rewritten->getRegion(i).begin());
358358

359+
rewriter.replaceOp(op, rewritten->getResults());
360+
359361
return success();
360362
}
361363
};
@@ -407,6 +409,35 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) {
407409
return resumeOp;
408410
}
409411

412+
/// In some cases such as scf.for, the blocks generated when it gets lowered
413+
/// depend on the parent region having already been lowered and having a
414+
/// converter assigned to it - this pattern assures that execute ops have a
415+
/// converter because they will actually be lowered only after everything else
416+
/// has been converted to llvm
417+
class ConvertExecuteOpTypes : public ConvertOpToLLVMPattern<async::ExecuteOp> {
418+
public:
419+
using ConvertOpToLLVMPattern<async::ExecuteOp>::ConvertOpToLLVMPattern;
420+
LogicalResult
421+
matchAndRewrite(async::ExecuteOp op, OpAdaptor adaptor,
422+
ConversionPatternRewriter &rewriter) const override {
423+
async::ExecuteOp newOp = cast<async::ExecuteOp>(
424+
rewriter.cloneWithoutRegions(*op.getOperation()));
425+
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
426+
newOp.getRegion().end());
427+
428+
// Set operands and update block argument and result types.
429+
newOp->setOperands(adaptor.getOperands());
430+
if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter)))
431+
return failure();
432+
for (auto result : newOp.getResults())
433+
result.setType(typeConverter->convertType(result.getType()));
434+
435+
newOp->setAttr("polygeist.handled", rewriter.getUnitAttr());
436+
rewriter.replaceOp(op, newOp.getResults());
437+
return success();
438+
}
439+
};
440+
410441
struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
411442
using ConvertOpToLLVMPattern<async::ExecuteOp>::ConvertOpToLLVMPattern;
412443

@@ -423,12 +454,12 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
423454

424455
// Make sure that all constants will be inside the outlined async function
425456
// to reduce the number of function arguments.
426-
Region &funcReg = execute.getBodyRegion();
457+
Region &execReg = execute.getBodyRegion();
427458

428459
// Collect all outlined function inputs.
429460
SetVector<mlir::Value> functionInputs;
430461

431-
getUsedValuesDefinedAbove(execute.getBodyRegion(), funcReg, functionInputs);
462+
getUsedValuesDefinedAbove(execute.getBodyRegion(), execReg, functionInputs);
432463
SmallVector<Value> toErase;
433464
for (auto a : functionInputs) {
434465
Operation *op = a.getDefiningOp();
@@ -451,16 +482,18 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
451482

452483
// TODO: Derive outlined function name from the parent FuncOp (support
453484
// multiple nested async.execute operations).
454-
auto moduleBuilder =
455-
ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
456-
457-
static int off = 0;
458-
off++;
459-
auto func = moduleBuilder.create<LLVM::LLVMFuncOp>(
460-
execute.getLoc(),
461-
"kernelbody." + std::to_string((long long int)&execute) + "." +
462-
std::to_string(off),
463-
funcType);
485+
LLVM::LLVMFuncOp func;
486+
{
487+
OpBuilder::InsertionGuard guard(rewriter);
488+
rewriter.setInsertionPointToEnd(module.getBody());
489+
static int off = 0;
490+
off++;
491+
func = rewriter.create<LLVM::LLVMFuncOp>(
492+
execute.getLoc(),
493+
"kernelbody." + std::to_string((long long int)&execute) + "." +
494+
std::to_string(off),
495+
funcType);
496+
}
464497

465498
rewriter.setInsertionPointToStart(func.addEntryBlock());
466499
BlockAndValueMapping valueMapping;
@@ -522,10 +555,17 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern<async::ExecuteOp> {
522555

523556
// Clone all operations from the execute operation body into the outlined
524557
// function body.
525-
for (Operation &op : execute.getBody()->without_terminator())
526-
rewriter.clone(op, valueMapping);
527-
528-
rewriter.create<LLVM::ReturnOp>(execute.getLoc(), ValueRange());
558+
rewriter.cloneRegionBefore(execute.getBodyRegion(), func.getRegion(),
559+
func.getRegion().end(), valueMapping);
560+
rewriter.create<LLVM::BrOp>(execute.getLoc(), ValueRange(),
561+
&*std::next(func.getRegion().begin()));
562+
for (Block &b : func.getRegion()) {
563+
auto term = b.getTerminator();
564+
if (isa<async::YieldOp>(term)) {
565+
rewriter.setInsertionPointToEnd(&b);
566+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(term, ValueRange());
567+
}
568+
}
529569
}
530570

531571
// Replace the original `async.execute` with a call to outlined function.
@@ -703,7 +743,7 @@ struct AllocLikeOpLowering : public ConvertOpToLLVMPattern<OpTy> {
703743
};
704744

705745
/// Pattern for lowering automatic stack allocations.
706-
struct AllocaOpLowering : public AllocLikeOpLowering<memref::AllocaOp> {
746+
struct CAllocaOpLowering : public AllocLikeOpLowering<memref::AllocaOp> {
707747
public:
708748
using AllocLikeOpLowering<memref::AllocaOp>::AllocLikeOpLowering;
709749

@@ -729,7 +769,7 @@ struct AllocaOpLowering : public AllocLikeOpLowering<memref::AllocaOp> {
729769
};
730770

731771
/// Pattern for lowering heap allocations via malloc.
732-
struct AllocOpLowering : public AllocLikeOpLowering<memref::AllocOp> {
772+
struct CAllocOpLowering : public AllocLikeOpLowering<memref::AllocOp> {
733773
public:
734774
using AllocLikeOpLowering<memref::AllocOp>::AllocLikeOpLowering;
735775

@@ -783,7 +823,7 @@ struct AllocOpLowering : public AllocLikeOpLowering<memref::AllocOp> {
783823
};
784824

785825
/// Pattern for lowering heap deallocations via free.
786-
struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
826+
struct CDeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
787827
public:
788828
using ConvertOpToLLVMPattern<memref::DeallocOp>::ConvertOpToLLVMPattern;
789829

@@ -914,7 +954,7 @@ struct GetGlobalOpLowering
914954

915955
/// Base class for patterns lowering memory access operations.
916956
template <typename OpTy>
917-
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<OpTy> {
957+
struct CLoadStoreOpLowering : public ConvertOpToLLVMPattern<OpTy> {
918958
protected:
919959
using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
920960

@@ -941,9 +981,9 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<OpTy> {
941981
};
942982

943983
/// Pattern for lowering a memory load.
944-
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
984+
struct CLoadOpLowering : public CLoadStoreOpLowering<memref::LoadOp> {
945985
public:
946-
using LoadStoreOpLowering<memref::LoadOp>::LoadStoreOpLowering;
986+
using CLoadStoreOpLowering<memref::LoadOp>::CLoadStoreOpLowering;
947987

948988
LogicalResult
949989
matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
@@ -958,9 +998,9 @@ struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
958998
};
959999

9601000
/// Pattern for lowering a memory store.
961-
struct StoreOpLowering : public LoadStoreOpLowering<memref::StoreOp> {
1001+
struct CStoreOpLowering : public CLoadStoreOpLowering<memref::StoreOp> {
9621002
public:
963-
using LoadStoreOpLowering<memref::StoreOp>::LoadStoreOpLowering;
1003+
using CLoadStoreOpLowering<memref::StoreOp>::CLoadStoreOpLowering;
9641004

9651005
LogicalResult
9661006
matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
@@ -1242,9 +1282,9 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
12421282
static void
12431283
populateCStyleMemRefLoweringPatterns(RewritePatternSet &patterns,
12441284
LLVMTypeConverter &typeConverter) {
1245-
patterns.add<AllocaOpLowering, AllocOpLowering, DeallocOpLowering,
1246-
GetGlobalOpLowering, GlobalOpLowering, LoadOpLowering,
1247-
StoreOpLowering>(typeConverter);
1285+
patterns.add<CAllocaOpLowering, CAllocOpLowering, CDeallocOpLowering,
1286+
GetGlobalOpLowering, GlobalOpLowering, CLoadOpLowering,
1287+
CStoreOpLowering, AllocaScopeOpLowering>(typeConverter);
12481288
}
12491289

12501290
/// Appends the patterns lowering operations from the Func dialect to the LLVM
@@ -1292,40 +1332,42 @@ struct ConvertPolygeistToLLVMPass
12921332

12931333
options.dataLayout = llvm::DataLayout(this->dataLayout);
12941334

1295-
for (int i = 0; i < 2; i++) {
1335+
// Define the type converter. Override the default behavior for memrefs if
1336+
// requested.
1337+
LLVMTypeConverter converter(&getContext(), options, &dataLayoutAnalysis);
1338+
if (useCStyleMemRef) {
1339+
converter.addConversion([&](MemRefType type) -> Optional<Type> {
1340+
Type converted = converter.convertType(type.getElementType());
1341+
if (!converted)
1342+
return Type();
12961343

1297-
// Define the type converter. Override the default behavior for memrefs if
1298-
// requested.
1299-
LLVMTypeConverter converter(&getContext(), options, &dataLayoutAnalysis);
1300-
if (useCStyleMemRef) {
1301-
converter.addConversion([&](MemRefType type) -> Optional<Type> {
1302-
Type converted = converter.convertType(type.getElementType());
1303-
if (!converted)
1304-
return Type();
1305-
1306-
if (type.getRank() == 0) {
1307-
return LLVM::LLVMPointerType::get(converted,
1308-
type.getMemorySpaceAsInt());
1309-
}
1310-
1311-
// Only the leading dimension can be dynamic.
1312-
if (llvm::any_of(type.getShape().drop_front(), ShapedType::isDynamic))
1313-
return Type();
1314-
1315-
// Only identity layout is supported.
1316-
// TODO: detect the strided layout that is equivalent to identity
1317-
// given the static part of the shape.
1318-
if (!type.getLayout().isIdentity())
1319-
return Type();
1320-
1321-
if (type.getRank() > 0) {
1322-
for (int64_t size : llvm::reverse(type.getShape().drop_front()))
1323-
converted = LLVM::LLVMArrayType::get(converted, size);
1324-
}
1344+
if (type.getRank() == 0) {
13251345
return LLVM::LLVMPointerType::get(converted,
13261346
type.getMemorySpaceAsInt());
1327-
});
1328-
}
1347+
}
1348+
1349+
// Only the leading dimension can be dynamic.
1350+
if (llvm::any_of(type.getShape().drop_front(), ShapedType::isDynamic))
1351+
return Type();
1352+
1353+
// Only identity layout is supported.
1354+
// TODO: detect the strided layout that is equivalent to identity
1355+
// given the static part of the shape.
1356+
if (!type.getLayout().isIdentity())
1357+
return Type();
1358+
1359+
if (type.getRank() > 0) {
1360+
for (int64_t size : llvm::reverse(type.getShape().drop_front()))
1361+
converted = LLVM::LLVMArrayType::get(converted, size);
1362+
}
1363+
return LLVM::LLVMPointerType::get(converted,
1364+
type.getMemorySpaceAsInt());
1365+
});
1366+
}
1367+
1368+
converter.addConversion([&](async::TokenType type) { return type; });
1369+
1370+
for (int i = 0; i < 2; i++) {
13291371

13301372
RewritePatternSet patterns(&getContext());
13311373
populatePolygeistToLLVMConversionPatterns(converter, patterns);
@@ -1343,8 +1385,6 @@ struct ConvertPolygeistToLLVMPass
13431385
populateOpenMPToLLVMConversionPatterns(converter, patterns);
13441386
arith::populateArithToLLVMConversionPatterns(converter, patterns);
13451387

1346-
converter.addConversion([&](async::TokenType type) { return type; });
1347-
13481388
patterns.add<LLVMOpLowering, GlobalOpTypeConversion,
13491389
ReturnOpTypeConversion, GetFuncOpConversion>(converter);
13501390
patterns.add<URLLVMOpLowering>(converter);
@@ -1399,10 +1439,16 @@ struct ConvertPolygeistToLLVMPass
13991439
op->getResult(0).getType(); });
14001440
*/
14011441

1402-
if (i == 1) {
1442+
if (i == 0) {
1443+
patterns.add<ConvertExecuteOpTypes>(converter);
1444+
target.addDynamicallyLegalOp<async::ExecuteOp>(
1445+
[&](async::ExecuteOp eo) {
1446+
return eo->hasAttr("polygeist.handled");
1447+
});
1448+
} else if (i == 1) {
14031449
// target.addIllegalOp<UnrealizedConversionCastOp>();
1404-
patterns.add<AsyncOpLowering>(converter);
14051450
patterns.add<StreamToTokenOpLowering>(converter);
1451+
patterns.add<AsyncOpLowering>(converter);
14061452
}
14071453
if (failed(applyPartialConversion(m, target, std::move(patterns))))
14081454
signalPassFailure();

0 commit comments

Comments
 (0)