diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 01a16e7c7b1e5..f6888d001fed6 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -778,6 +778,7 @@ class TargetTransformInfoImplBase { case Intrinsic::experimental_gc_relocate: case Intrinsic::coro_alloc: case Intrinsic::coro_begin: + case Intrinsic::coro_begin_custom_abi: case Intrinsic::coro_free: case Intrinsic::coro_end: case Intrinsic::coro_frame: diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td index 20dd921ddbd23..8a0721cf23f53 100644 --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1719,7 +1719,8 @@ def int_coro_prepare_async : Intrinsic<[llvm_ptr_ty], [llvm_ptr_ty], [IntrNoMem]>; def int_coro_begin : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty], [WriteOnly<ArgIndex<1>>]>; - +def int_coro_begin_custom_abi : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty, llvm_i32_ty], + [WriteOnly<ArgIndex<1>>]>; def int_coro_free : Intrinsic<[llvm_ptr_ty], [llvm_token_ty, llvm_ptr_ty], [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<1>>, diff --git a/llvm/include/llvm/Transforms/Coroutines/ABI.h b/llvm/include/llvm/Transforms/Coroutines/ABI.h index e7568d275c161..8b83c5308056e 100644 --- a/llvm/include/llvm/Transforms/Coroutines/ABI.h +++ b/llvm/include/llvm/Transforms/Coroutines/ABI.h @@ -29,7 +29,13 @@ namespace coro { // This interface/API is to provide an object oriented way to implement ABI // functionality. This is intended to replace use of the ABI enum to perform // ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common -// ABIs. +// ABIs. However, specific users may need to modify the behavior of these. This +// can be accomplished by inheriting one of the common ABIs and overriding one +// or more of the methods to create a custom ABI. 
To use a custom ABI for a +// given coroutine the coro.begin.custom.abi intrinsic is used in place of the +// coro.begin intrinsic. This takes an additional i32 arg that specifies the +// index of an ABI generator for the custom ABI object in a SmallVector passed +// to CoroSplitPass ctor. class BaseABI { public: diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h index a329a06bf1389..3aa30bec85c3a 100644 --- a/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h +++ b/llvm/include/llvm/Transforms/Coroutines/CoroInstr.h @@ -124,7 +124,8 @@ class AnyCoroIdInst : public IntrinsicInst { IntrinsicInst *getCoroBegin() { for (User *U : users()) if (auto *II = dyn_cast<IntrinsicInst>(U)) - if (II->getIntrinsicID() == Intrinsic::coro_begin) + if (II->getIntrinsicID() == Intrinsic::coro_begin || + II->getIntrinsicID() == Intrinsic::coro_begin_custom_abi) return II; llvm_unreachable("no coro.begin associated with coro.id"); } @@ -442,20 +443,30 @@ class CoroFreeInst : public IntrinsicInst { } }; -/// This class represents the llvm.coro.begin instructions. +/// This class represents the llvm.coro.begin or llvm.coro.begin.custom.abi +/// instructions. 
class CoroBeginInst : public IntrinsicInst { - enum { IdArg, MemArg }; + enum { IdArg, MemArg, CustomABIArg }; public: AnyCoroIdInst *getId() const { return cast<AnyCoroIdInst>(getArgOperand(IdArg)); } + bool hasCustomABI() const { + return getIntrinsicID() == Intrinsic::coro_begin_custom_abi; + } + + int getCustomABI() const { + return cast<ConstantInt>(getArgOperand(CustomABIArg))->getZExtValue(); + } + Value *getMem() const { return getArgOperand(MemArg); } // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I) { - return I->getIntrinsicID() == Intrinsic::coro_begin; + return I->getIntrinsicID() == Intrinsic::coro_begin || + I->getIntrinsicID() == Intrinsic::coro_begin_custom_abi; } static bool classof(const Value *V) { return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h index a5fd57f8f9dfa..6c6a982e82805 100644 --- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h +++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h @@ -28,17 +28,26 @@ struct Shape; } // namespace coro struct CoroSplitPass : PassInfoMixin<CoroSplitPass> { + using BaseABITy = + std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>; CoroSplitPass(bool OptimizeFrame = false); + + CoroSplitPass(SmallVector<BaseABITy> GenCustomABIs, + bool OptimizeFrame = false); + CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback, bool OptimizeFrame = false); + CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback, + SmallVector<BaseABITy> GenCustomABIs, + bool OptimizeFrame = false); + PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR); + static bool isRequired() { return true; } - using BaseABITy = - std::function<std::unique_ptr<coro::BaseABI>(Function &, coro::Shape &)>; // Generator for an ABI transformer BaseABITy CreateAndInitABI; diff --git a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp index dd92b3593af92..1cda7f93f72a2 100644 --- 
a/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -53,6 +53,7 @@ bool Lowerer::lower(Function &F) { default: continue; case Intrinsic::coro_begin: + case Intrinsic::coro_begin_custom_abi: II->replaceAllUsesWith(II->getArgOperand(1)); break; case Intrinsic::coro_free: @@ -112,7 +113,8 @@ static bool declaresCoroCleanupIntrinsics(const Module &M) { M, {"llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.subfn.addr", "llvm.coro.free", "llvm.coro.id", "llvm.coro.id.retcon", "llvm.coro.id.async", "llvm.coro.id.retcon.once", - "llvm.coro.async.size.replace", "llvm.coro.async.resume"}); + "llvm.coro.async.size.replace", "llvm.coro.async.resume", + "llvm.coro.begin.custom.abi"}); } PreservedAnalyses CoroCleanupPass::run(Module &M, diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp index ef1f27118bc14..88ce331c8cfb6 100644 --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -2200,7 +2200,15 @@ static void addPrepareFunction(const Module &M, static std::unique_ptr<coro::BaseABI> CreateNewABI(Function &F, coro::Shape &S, - std::function<bool(Instruction &)> IsMatCallback) { + std::function<bool(Instruction &)> IsMatCallback, + const SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs) { + if (S.CoroBegin->hasCustomABI()) { + unsigned CustomABI = S.CoroBegin->getCustomABI(); + if (CustomABI >= GenCustomABIs.size()) + llvm_unreachable("Custom ABI not found among those specified"); + return GenCustomABIs[CustomABI](F, S); + } + switch (S.ABI) { case coro::ABI::Switch: return std::unique_ptr<coro::BaseABI>( @@ -2221,7 +2229,17 @@ CreateNewABI(Function &F, coro::Shape &S, CoroSplitPass::CoroSplitPass(bool OptimizeFrame) : CreateAndInitABI([](Function &F, coro::Shape &S) { std::unique_ptr<coro::BaseABI> ABI = - CreateNewABI(F, S, coro::isTriviallyMaterializable); + CreateNewABI(F, S, coro::isTriviallyMaterializable, {}); + ABI->init(); + return ABI; + }), + OptimizeFrame(OptimizeFrame) {} + +CoroSplitPass::CoroSplitPass( + SmallVector<CoroSplitPass::BaseABITy> 
GenCustomABIs, bool OptimizeFrame) + : CreateAndInitABI([=](Function &F, coro::Shape &S) { + std::unique_ptr<coro::BaseABI> ABI = + CreateNewABI(F, S, coro::isTriviallyMaterializable, GenCustomABIs); ABI->init(); return ABI; }), @@ -2232,7 +2250,21 @@ CoroSplitPass::CoroSplitPass(bool OptimizeFrame) CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback, bool OptimizeFrame) : CreateAndInitABI([=](Function &F, coro::Shape &S) { - std::unique_ptr<coro::BaseABI> ABI = CreateNewABI(F, S, IsMatCallback); + std::unique_ptr<coro::BaseABI> ABI = + CreateNewABI(F, S, IsMatCallback, {}); + ABI->init(); + return ABI; + }), + OptimizeFrame(OptimizeFrame) {} + +// For backward compatibility, constructor takes a materializable callback and +// creates a generator for an ABI with a modified materializable callback. +CoroSplitPass::CoroSplitPass( + std::function<bool(Instruction &)> IsMatCallback, + SmallVector<CoroSplitPass::BaseABITy> GenCustomABIs, bool OptimizeFrame) + : CreateAndInitABI([=](Function &F, coro::Shape &S) { + std::unique_ptr<coro::BaseABI> ABI = + CreateNewABI(F, S, IsMatCallback, GenCustomABIs); ABI->init(); return ABI; }), diff --git a/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/llvm/lib/Transforms/Coroutines/Coroutines.cpp index f4d9a7a8aa856..1c45bcd7f6a83 100644 --- a/llvm/lib/Transforms/Coroutines/Coroutines.cpp +++ b/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -73,6 +73,7 @@ static const char *const CoroIntrinsics[] = { "llvm.coro.await.suspend.handle", "llvm.coro.await.suspend.void", "llvm.coro.begin", + "llvm.coro.begin.custom.abi", "llvm.coro.destroy", "llvm.coro.done", "llvm.coro.end", @@ -247,7 +248,8 @@ void coro::Shape::analyze(Function &F, } break; } - case Intrinsic::coro_begin: { + case Intrinsic::coro_begin: + case Intrinsic::coro_begin_custom_abi: { auto CB = cast<CoroBeginInst>(II); // Ignore coro id's that aren't pre-split. 
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp index 1d55889a32d7a..c3394fdaa940b 100644 --- a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp +++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp @@ -182,4 +182,91 @@ TEST_F(ExtraRematTest, TestCoroRematWithCallback) { CallInst *CI = getCallByName(Resume1, "should.remat"); ASSERT_TRUE(CI); } + +StringRef TextCoroBeginCustomABI = R"( + define ptr @f(i32 %n) presplitcoroutine { + entry: + %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) + %size = call i32 @llvm.coro.size.i32() + %alloc = call ptr @malloc(i32 %size) + %hdl = call ptr @llvm.coro.begin.custom.abi(token %id, ptr %alloc, i32 0) + + %inc1 = add i32 %n, 1 + %val2 = call i32 @should.remat(i32 %inc1) + %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) + switch i8 %sp1, label %suspend [i8 0, label %resume1 + i8 1, label %cleanup] + resume1: + %inc2 = add i32 %val2, 1 + %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) + switch i8 %sp1, label %suspend [i8 0, label %resume2 + i8 1, label %cleanup] + + resume2: + call void @print(i32 %val2) + call void @print(i32 %inc2) + br label %cleanup + + cleanup: + %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) + call void @free(ptr %mem) + br label %suspend + suspend: + call i1 @llvm.coro.end(ptr %hdl, i1 0) + ret ptr %hdl + } + + declare ptr @llvm.coro.free(token, ptr) + declare i32 @llvm.coro.size.i32() + declare i8 @llvm.coro.suspend(token, i1) + declare void @llvm.coro.resume(ptr) + declare void @llvm.coro.destroy(ptr) + + declare token @llvm.coro.id(i32, ptr, ptr, ptr) + declare i1 @llvm.coro.alloc(token) + declare ptr @llvm.coro.begin.custom.abi(token, ptr, i32) + declare i1 @llvm.coro.end(ptr, i1) + + declare i32 @should.remat(i32) + + declare noalias ptr @malloc(i32) + declare void @print(i32) + declare void @free(ptr) + )"; + +// SwitchABI with overridden isMaterializable +class 
ExtraCustomABI : public coro::SwitchABI { +public: + ExtraCustomABI(Function &F, coro::Shape &S) + : coro::SwitchABI(F, S, ExtraMaterializable) {} +}; + +TEST_F(ExtraRematTest, TestCoroRematWithCustomABI) { + ParseAssembly(TextCoroBeginCustomABI); + + ASSERT_TRUE(M); + + CoroSplitPass::BaseABITy GenCustomABI = [](Function &F, coro::Shape &S) { + return std::unique_ptr<coro::BaseABI>(new ExtraCustomABI(F, S)); + }; + + CGSCCPassManager CGPM; + CGPM.addPass(CoroSplitPass({GenCustomABI})); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); + + // Verify that extra rematerializable instruction has been rematerialized + Function *F = M->getFunction("f.resume"); + ASSERT_TRUE(F) << "could not find split function f.resume"; + + BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); + ASSERT_TRUE(Resume1) + << "could not find expected BB resume1 in split function"; + + // With callback the extra rematerialization of the function should have + // happened + CallInst *CI = getCallByName(Resume1, "should.remat"); + ASSERT_TRUE(CI); +} + } // namespace