diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 6a657724dc611..0b8f22719faf1 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -31,6 +31,7 @@ namespace llvm { class CanonicalLoopInfo; +class CodeExtractor; class ScanInfo; struct TargetRegionEntryInfo; class OffloadEntriesInfoManager; @@ -2264,17 +2265,27 @@ class OpenMPIRBuilder { BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB; SmallVector ExcludeArgsFromAggregate; + LLVM_ABI virtual ~OutlineInfo() = default; + /// Collect all blocks in between EntryBB and ExitBB in both the given /// vector and set. LLVM_ABI void collectBlocks(SmallPtrSetImpl &BlockSet, SmallVectorImpl &BlockVector); + /// Create a CodeExtractor instance based on the information stored in this + /// structure, the list of collected blocks from a previous call to + /// \c collectBlocks and a flag stating whether arguments must be passed in + /// address space 0. + LLVM_ABI virtual std::unique_ptr + createCodeExtractor(ArrayRef Blocks, + bool ArgsInZeroAddressSpace, Twine Suffix = Twine("")); + /// Return the function that contains the region to be outlined. Function *getFunction() const { return EntryBB->getParent(); } }; /// Collection of regions that need to be outlined during finalization. - SmallVector OutlineInfos; + SmallVector, 16> OutlineInfos; /// A collection of candidate target functions that's constant allocas will /// attempt to be raised on a call of finalize after all currently enqueued @@ -2289,7 +2300,9 @@ class OpenMPIRBuilder { std::forward_list ScanInfos; /// Add a new region that will be outlined later. - void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); } + void addOutlineInfo(std::unique_ptr &&OI) { + OutlineInfos.emplace_back(std::move(OI)); + } /// An ordered map of auto-generated variables to their unique names. /// It stores variables with the following names: 1) ".gomp_critical_user_" + diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h index 407eb50d2c7a3..b3bea96039172 100644 --- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h @@ -17,14 +17,15 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/Support/Compiler.h" #include namespace llvm { template class SmallPtrSetImpl; +class AddrSpaceCastInst; class AllocaInst; -class BasicBlock; class BlockFrequency; class BlockFrequencyInfo; class BranchProbabilityInfo; @@ -94,15 +95,23 @@ class CodeExtractorAnalysisCache { BranchProbabilityInfo *BPI; AssumptionCache *AC; - // A block outside of the extraction set where any intermediate - // allocations will be placed inside. If this is null, allocations - // will be placed in the entry block of the function. + /// A block outside of the extraction set where any intermediate + /// allocations will be placed inside. If this is null, allocations + /// will be placed in the entry block of the function. BasicBlock *AllocationBlock; - // If true, varargs functions can be extracted. + /// A block outside of the extraction set where deallocations for + /// intermediate allocations can be placed inside. Not used for + /// automatically deallocated memory (e.g. `alloca`), which is the default. + /// + /// If it is null and needed, the end of the replacement basic block will be + /// used to place deallocations. + BasicBlock *DeallocationBlock; + + /// If true, varargs functions can be extracted. bool AllowVarArgs; - // Bits of intermediate state computed at various phases of extraction. + /// Bits of intermediate state computed at various phases of extraction. SetVector Blocks; /// Lists of blocks that are branched from the code region to be extracted, @@ -124,13 +133,13 @@ class CodeExtractorAnalysisCache { /// returns 1, etc. SmallVector ExtractedFuncRetVals; - // Suffix to use when creating extracted function (appended to the original - // function name + "."). If empty, the default is to use the entry block - // label, if non-empty, otherwise "extracted". + /// Suffix to use when creating extracted function (appended to the original + /// function name + "."). If empty, the default is to use the entry block + /// label, if non-empty, otherwise "extracted". std::string Suffix; - // If true, the outlined function has aggregate argument in zero address - // space. + /// If true, the outlined function has aggregate argument in zero address + /// space. bool ArgsInZeroAddressSpace; public: @@ -146,7 +155,9 @@ class CodeExtractorAnalysisCache { /// however code extractor won't validate whether extraction is legal. /// Any new allocations will be placed in the AllocationBlock, unless /// it is null, in which case it will be placed in the entry block of - /// the function from which the code is being extracted. + /// the function from which the code is being extracted. Explicit + /// deallocations for the aforementioned allocations will be placed in the + /// DeallocationBlock or the end of the replacement block, if needed. /// If ArgsInZeroAddressSpace param is set to true, then the aggregate /// param pointer of the outlined function is declared in zero address /// space. @@ -157,8 +168,11 @@ class CodeExtractorAnalysisCache { AssumptionCache *AC = nullptr, bool AllowVarArgs = false, bool AllowAlloca = false, BasicBlock *AllocationBlock = nullptr, + BasicBlock *DeallocationBlock = nullptr, std::string Suffix = "", bool ArgsInZeroAddressSpace = false); + LLVM_ABI virtual ~CodeExtractor() = default; + /// Perform the extraction, returning the new function. /// /// Returns zero when called on a CodeExtractor instance where isEligible @@ -243,6 +257,19 @@ class CodeExtractorAnalysisCache { /// region, passing it instead as a scalar. LLVM_ABI void excludeArgFromAggregate(Value *Arg); + protected: + /// Allocate an intermediate variable at the specified point. + LLVM_ABI virtual Instruction * + allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP, Type *VarType, + const Twine &Name = Twine(""), + AddrSpaceCastInst **CastedAlloc = nullptr); + + /// Deallocate a previously-allocated intermediate variable at the specified + /// point. + LLVM_ABI virtual Instruction *deallocateVar(BasicBlock *BB, + BasicBlock::iterator DeallocIP, + Value *Var, Type *VarType); + private: struct LifetimeMarkerInfo { bool SinkLifeStart = false; diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index ab2a059e423c1..e0b7378c34f77 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -280,6 +280,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks, return Result; } +/// Given a function, if it represents the entry point of a target kernel, this +/// returns the execution mode flags associated with that kernel. +static std::optional +getTargetKernelExecMode(Function &Kernel) { + CallInst *TargetInitCall = nullptr; + for (Instruction &Inst : Kernel.getEntryBlock()) { + if (auto *Call = dyn_cast(&Inst)) { + if (Call->getCalledFunction()->getName() == "__kmpc_target_init") { + TargetInitCall = Call; + break; + } + } + } + + if (!TargetInitCall) + return std::nullopt; + + // Get the kernel mode information from the global variable associated to the + // first argument to the call to __kmpc_target_init. Refer to + // createTargetInit() to see how this is initialized. + Value *InitOperand = TargetInitCall->getArgOperand(0); + GlobalVariable *KernelEnv = nullptr; + if (auto *Cast = dyn_cast(InitOperand)) + KernelEnv = cast(Cast->getOperand(0)); + else + KernelEnv = cast(InitOperand); + auto *KernelEnvInit = cast(KernelEnv->getInitializer()); + auto *ConfigEnv = cast(KernelEnvInit->getOperand(0)); + auto *KernelMode = cast(ConfigEnv->getOperand(2)); + return static_cast(KernelMode->getZExtValue()); +} + /// Make \p Source branch to \p Target. /// /// Handles two situations: @@ -447,6 +479,88 @@ enum OpenMPOffloadingRequiresDirFlags { LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS) }; +class OMPCodeExtractor : public CodeExtractor { +public: + OMPCodeExtractor(OpenMPIRBuilder &OMPBuilder, ArrayRef BBs, + DominatorTree *DT = nullptr, bool AggregateArgs = false, + BlockFrequencyInfo *BFI = nullptr, + BranchProbabilityInfo *BPI = nullptr, + AssumptionCache *AC = nullptr, bool AllowVarArgs = false, + bool AllowAlloca = false, + BasicBlock *AllocationBlock = nullptr, + BasicBlock *DeallocationBlock = nullptr, + std::string Suffix = "", bool ArgsInZeroAddressSpace = false) + : CodeExtractor(BBs, DT, AggregateArgs, BFI, BPI, AC, AllowVarArgs, + AllowAlloca, AllocationBlock, DeallocationBlock, Suffix, + ArgsInZeroAddressSpace), + OMPBuilder(OMPBuilder) {} + + virtual ~OMPCodeExtractor() = default; + +protected: + OpenMPIRBuilder &OMPBuilder; +}; + +class DeviceSharedMemCodeExtractor : public OMPCodeExtractor { +public: + DeviceSharedMemCodeExtractor( + OpenMPIRBuilder &OMPBuilder, BasicBlock *AllocBlockOverride, + ArrayRef BBs, DominatorTree *DT = nullptr, + bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, + BranchProbabilityInfo *BPI = nullptr, AssumptionCache *AC = nullptr, + bool AllowVarArgs = false, bool AllowAlloca = false, + BasicBlock *AllocationBlock = nullptr, + BasicBlock *DeallocationBlock = nullptr, std::string Suffix = "", + bool ArgsInZeroAddressSpace = false) + : OMPCodeExtractor(OMPBuilder, BBs, DT, AggregateArgs, BFI, BPI, AC, + AllowVarArgs, AllowAlloca, AllocationBlock, + DeallocationBlock, Suffix, ArgsInZeroAddressSpace), + AllocBlockOverride(AllocBlockOverride) {} + virtual ~DeviceSharedMemCodeExtractor() = default; + +protected: + virtual Instruction * + allocateVar(BasicBlock *, BasicBlock::iterator, Type *VarType, + const Twine &Name = Twine(""), + AddrSpaceCastInst **CastedAlloc = nullptr) override { + // Ignore the CastedAlloc pointer, if requested, because shared memory + // should not be casted to address space 0 to be passed around. + return OMPBuilder.createOMPAllocShared( + OpenMPIRBuilder::InsertPointTy( + AllocBlockOverride, AllocBlockOverride->getFirstInsertionPt()), + VarType, Name); + } + + virtual Instruction *deallocateVar(BasicBlock *BB, + BasicBlock::iterator DeallocIP, Value *Var, + Type *VarType) override { + return OMPBuilder.createOMPFreeShared( + OpenMPIRBuilder::InsertPointTy(BB, DeallocIP), Var, VarType); + } + +private: + // TODO: Remove the need for this override and instead get the CodeExtractor + // to provide a valid insert point for explicit deallocations by correctly + // populating its DeallocationBlock. + BasicBlock *AllocBlockOverride; +}; + +/// Helper storing information about regions to outline using device shared +/// memory for intermediate allocations. +struct DeviceSharedMemOutlineInfo : public OpenMPIRBuilder::OutlineInfo { + OpenMPIRBuilder &OMPBuilder; + BasicBlock *AllocBlockOverride = nullptr; + + DeviceSharedMemOutlineInfo(OpenMPIRBuilder &OMPBuilder) + : OMPBuilder(OMPBuilder) {} + virtual ~DeviceSharedMemOutlineInfo() = default; + + virtual std::unique_ptr + createCodeExtractor(ArrayRef Blocks, + bool ArgsInZeroAddressSpace, + Twine Suffix = Twine("")) override; +}; + } // anonymous namespace OpenMPIRBuilderConfig::OpenMPIRBuilderConfig() @@ -704,20 +818,20 @@ static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder, void OpenMPIRBuilder::finalize(Function *Fn) { SmallPtrSet ParallelRegionBlockSet; SmallVector Blocks; - SmallVector DeferredOutlines; - for (OutlineInfo &OI : OutlineInfos) { + SmallVector, 16> DeferredOutlines; + for (std::unique_ptr &OI : OutlineInfos) { // Skip functions that have not finalized yet; may happen with nested // function generation. - if (Fn && OI.getFunction() != Fn) { - DeferredOutlines.push_back(OI); + if (Fn && OI->getFunction() != Fn) { + DeferredOutlines.push_back(std::move(OI)); continue; } ParallelRegionBlockSet.clear(); Blocks.clear(); - OI.collectBlocks(ParallelRegionBlockSet, Blocks); + OI->collectBlocks(ParallelRegionBlockSet, Blocks); - Function *OuterFn = OI.getFunction(); + Function *OuterFn = OI->getFunction(); CodeExtractorAnalysisCache CEAC(*OuterFn); // If we generate code for the target device, we need to allocate // struct for aggregate params in the device default alloca address space. @@ -726,26 +840,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) { // CodeExtractor generates correct code for extracted functions // which are used by OpenMP runtime. bool ArgsInZeroAddressSpace = Config.isTargetDevice(); - CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr, - /* AggregateArgs */ true, - /* BlockFrequencyInfo */ nullptr, - /* BranchProbabilityInfo */ nullptr, - /* AssumptionCache */ nullptr, - /* AllowVarArgs */ true, - /* AllowAlloca */ true, - /* AllocaBlock*/ OI.OuterAllocaBB, - /* Suffix */ ".omp_par", ArgsInZeroAddressSpace); + std::unique_ptr Extractor = + OI->createCodeExtractor(Blocks, ArgsInZeroAddressSpace, ".omp_par"); LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n"); - LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName() - << " Exit: " << OI.ExitBB->getName() << "\n"); - assert(Extractor.isEligible() && + LLVM_DEBUG(dbgs() << "Entry " << OI->EntryBB->getName() + << " Exit: " << OI->ExitBB->getName() << "\n"); + assert(Extractor->isEligible() && "Expected OpenMP outlining to be possible!"); - for (auto *V : OI.ExcludeArgsFromAggregate) - Extractor.excludeArgFromAggregate(V); + for (auto *V : OI->ExcludeArgsFromAggregate) + Extractor->excludeArgFromAggregate(V); - Function *OutlinedFn = Extractor.extractCodeRegion(CEAC); + Function *OutlinedFn = Extractor->extractCodeRegion(CEAC); // Forward target-cpu, target-features attributes to the outlined function. auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu"); @@ -770,8 +877,8 @@ void OpenMPIRBuilder::finalize(Function *Fn) { // made our own entry block after all. { BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock(); - assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB); - assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry); + assert(ArtificialEntry.getUniqueSuccessor() == OI->EntryBB); + assert(OI->EntryBB->getUniquePredecessor() == &ArtificialEntry); // Move instructions from the to-be-deleted ArtificialEntry to the entry // basic block of the parallel region. CodeExtractor generates // instructions to unwrap the aggregate argument and may sink @@ -787,24 +894,25 @@ void OpenMPIRBuilder::finalize(Function *Fn) { if (I.isTerminator()) { // Absorb any debug value that terminator may have - if (OI.EntryBB->getTerminator()) - OI.EntryBB->getTerminator()->adoptDbgRecords( + if (OI->EntryBB->getTerminator()) + OI->EntryBB->getTerminator()->adoptDbgRecords( &ArtificialEntry, I.getIterator(), false); continue; } - I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt()); + I.moveBeforePreserving(*OI->EntryBB, + OI->EntryBB->getFirstInsertionPt()); } - OI.EntryBB->moveBefore(&ArtificialEntry); + OI->EntryBB->moveBefore(&ArtificialEntry); ArtificialEntry.eraseFromParent(); } - assert(&OutlinedFn->getEntryBlock() == OI.EntryBB); + assert(&OutlinedFn->getEntryBlock() == OI->EntryBB); assert(OutlinedFn && OutlinedFn->hasNUses(1)); // Run a user callback, e.g. to add attributes. - if (OI.PostOutlineCB) - OI.PostOutlineCB(*OutlinedFn); + if (OI->PostOutlineCB) + OI->PostOutlineCB(*OutlinedFn); } // Remove work items that have been completed. @@ -1636,31 +1744,71 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel( LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n"); - OutlineInfo OI; + auto OI = [&]() -> std::unique_ptr { + if (Config.isTargetDevice()) { + std::optional ExecMode = + getTargetKernelExecMode(*OuterFn); + + // If OuterFn is not a Generic kernel, skip custom allocation. This causes + // the CodeExtractor to follow its default behavior. Otherwise, we need to + // use device shared memory to allocate argument structures. + if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) { + auto Info = std::make_unique(*this); + + // Instead of using the insertion point provided by the CodeExtractor, + // here we need to use the block that eventually calls the outlined + // function for the `parallel` construct. + // + // The reason is that the explicit deallocation call will be inserted + // within the outlined function, whereas the alloca insertion point + // might actually be located somewhere else in the caller. This becomes + // a problem when e.g. `parallel` is inside of a `distribute` construct, + // because the deallocation would be executed multiple times and the + // allocation just once (outside of the loop). + // + // TODO: Ideally, we'd want to do the allocation and deallocation + // outside of the `parallel` outlined function, hence using here the + // insertion point provided by the CodeExtractor. We can't do this at + // the moment because there is currently no way of passing an eligible + // insertion point for the explicit deallocation to the CodeExtractor, + // as that block is created (at least when nested inside of + // `distribute`) sometime after createParallel() completed, so it can't + // be stored in the OutlineInfo structure here. + // + // The current approach results in an explicit allocation and + // deallocation pair for each `distribute` loop iteration in that case, + // which is suboptimal. + Info->AllocBlockOverride = EntryBB; + return Info; + } + } + return std::make_unique(); + }(); + if (Config.isTargetDevice()) { // Generate OpenMP target specific runtime call - OI.PostOutlineCB = [=, ToBeDeletedVec = - std::move(ToBeDeleted)](Function &OutlinedFn) { + OI->PostOutlineCB = [=, ToBeDeletedVec = + std::move(ToBeDeleted)](Function &OutlinedFn) { targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident, IfCondition, NumThreads, PrivTID, PrivTIDAddr, ThreadID, ToBeDeletedVec); }; } else { // Generate OpenMP host runtime call - OI.PostOutlineCB = [=, ToBeDeletedVec = - std::move(ToBeDeleted)](Function &OutlinedFn) { + OI->PostOutlineCB = [=, ToBeDeletedVec = + std::move(ToBeDeleted)](Function &OutlinedFn) { hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition, PrivTID, PrivTIDAddr, ToBeDeletedVec); }; } - OI.OuterAllocaBB = OuterAllocaBlock; - OI.EntryBB = PRegEntryBB; - OI.ExitBB = PRegExitBB; + OI->OuterAllocaBB = OuterAllocaBlock; + OI->EntryBB = PRegEntryBB; + OI->ExitBB = PRegExitBB; SmallPtrSet ParallelRegionBlockSet; SmallVector Blocks; - OI.collectBlocks(ParallelRegionBlockSet, Blocks); + OI->collectBlocks(ParallelRegionBlockSet, Blocks); CodeExtractorAnalysisCache CEAC(*OuterFn); CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr, @@ -1671,6 +1819,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel( /* AllowVarArgs */ true, /* AllowAlloca */ true, /* AllocationBlock */ OuterAllocaBlock, + /* DeallocationBlock */ nullptr, /* Suffix */ ".omp_par", ArgsInZeroAddressSpace); // Find inputs to, outputs from the code region. @@ -1695,7 +1844,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel( auto PrivHelper = [&](Value &V) -> Error { if (&V == TIDAddr || &V == ZeroAddr) { - OI.ExcludeArgsFromAggregate.push_back(&V); + OI->ExcludeArgsFromAggregate.push_back(&V); return Error::success(); } @@ -1972,19 +2121,19 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask( if (Error Err = BodyGenCB(TaskAllocaIP, TaskBodyIP)) return Err; - OutlineInfo OI; - OI.EntryBB = TaskAllocaBB; - OI.OuterAllocaBB = AllocaIP.getBlock(); - OI.ExitBB = TaskExitBB; + auto OI = std::make_unique(); + OI->EntryBB = TaskAllocaBB; + OI->OuterAllocaBB = AllocaIP.getBlock(); + OI->ExitBB = TaskExitBB; // Add the thread ID argument. SmallVector ToBeDeleted; - OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal( Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false)); - OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies, - Mergeable, Priority, EventHandle, TaskAllocaBB, - ToBeDeleted](Function &OutlinedFn) mutable { + OI->PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies, + Mergeable, Priority, EventHandle, TaskAllocaBB, + ToBeDeleted](Function &OutlinedFn) mutable { // Replace the Stale CI by appropriate RTL function call. assert(OutlinedFn.hasOneUse() && "there must be a single user for the outlined function"); @@ -5079,19 +5228,19 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget( Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize); Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); - OutlineInfo OI; - OI.OuterAllocaBB = CLI->getPreheader(); + auto OI = std::make_unique(); + OI->OuterAllocaBB = CLI->getPreheader(); Function *OuterFn = CLI->getPreheader()->getParent(); // Instructions which need to be deleted at the end of code generation SmallVector ToBeDeleted; - OI.OuterAllocaBB = AllocaIP.getBlock(); + OI->OuterAllocaBB = AllocaIP.getBlock(); // Mark the body loop as region which needs to be extracted - OI.EntryBB = CLI->getBody(); - OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(), - "omp.prelatch", true); + OI->EntryBB = CLI->getBody(); + OI->ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(), + "omp.prelatch", true); // Prepare loop body for extraction Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()}); @@ -5111,7 +5260,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget( // loop body region. SmallPtrSet ParallelRegionBlockSet; SmallVector Blocks; - OI.collectBlocks(ParallelRegionBlockSet, Blocks); + OI->collectBlocks(ParallelRegionBlockSet, Blocks); CodeExtractorAnalysisCache CEAC(*OuterFn); CodeExtractor Extractor(Blocks, @@ -5123,6 +5272,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget( /* AllowVarArgs */ true, /* AllowAlloca */ true, /* AllocationBlock */ CLI->getPreheader(), + /* DeallocationBlock */ nullptr, /* Suffix */ ".omp_wsloop", /* AggrArgsIn0AddrSpace */ true); @@ -5147,15 +5297,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget( } // Make sure that loop counter variable is not merged into loop body // function argument structure and it is passed as separate variable - OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad); + OI->ExcludeArgsFromAggregate.push_back(NewLoopCntLoad); // PostOutline CB is invoked when loop body function is outlined and // loop body is replaced by call to outlined function. We need to add // call to OpenMP device rtl inside loop preheader. OpenMP device rtl // function will handle loop control logic. // - OI.PostOutlineCB = [=, ToBeDeletedVec = - std::move(ToBeDeleted)](Function &OutlinedFn) { + OI->PostOutlineCB = [=, ToBeDeletedVec = + std::move(ToBeDeleted)](Function &OutlinedFn) { workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ToBeDeletedVec, LoopType, NoLoop); }; @@ -7976,13 +8126,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( TargetTaskAllocaBB->begin()); InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin()); - OutlineInfo OI; - OI.EntryBB = TargetTaskAllocaBB; - OI.OuterAllocaBB = AllocaIP.getBlock(); + auto OI = std::make_unique(); + OI->EntryBB = TargetTaskAllocaBB; + OI->OuterAllocaBB = AllocaIP.getBlock(); // Add the thread ID argument. SmallVector ToBeDeleted; - OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal( Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false)); // Generate the task body which will subsequently be outlined. @@ -8000,8 +8150,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( // OI.ExitBlock is set to the single task body block and will get left out of // the outlining process. So, simply create a new empty block to which we // uncoditionally branch from where TaskBodyCB left off - OI.ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont"); - emitBlock(OI.ExitBB, Builder.GetInsertBlock()->getParent(), + OI->ExitBB = BasicBlock::Create(Builder.getContext(), "target.task.cont"); + emitBlock(OI->ExitBB, Builder.GetInsertBlock()->getParent(), /*IsFinished=*/true); SmallVector OffloadingArraysToPrivatize; @@ -8013,13 +8163,13 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitTargetTask( RTArgs.SizesArray}) { if (V && !isa(V)) { OffloadingArraysToPrivatize.push_back(V); - OI.ExcludeArgsFromAggregate.push_back(V); + OI->ExcludeArgsFromAggregate.push_back(V); } } } - OI.PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask, - DeviceID, OffloadingArraysToPrivatize]( - Function &OutlinedFn) mutable { + OI->PostOutlineCB = [this, ToBeDeleted, Dependencies, NeedsTargetTask, + DeviceID, OffloadingArraysToPrivatize]( + Function &OutlinedFn) mutable { assert(OutlinedFn.hasOneUse() && "there must be a single user for the outlined function"); @@ -9979,17 +10129,17 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, if (Error Err = BodyGenCB(AllocaIP, CodeGenIP)) return Err; - OutlineInfo OI; - OI.EntryBB = AllocaBB; - OI.ExitBB = ExitBB; - OI.OuterAllocaBB = &OuterAllocaBB; + auto OI = std::make_unique(); + OI->EntryBB = AllocaBB; + OI->ExitBB = ExitBB; + OI->OuterAllocaBB = &OuterAllocaBB; // Insert fake values for global tid and bound tid. SmallVector ToBeDeleted; InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin()); - OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal( Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true)); - OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal( + OI->ExcludeArgsFromAggregate.push_back(createFakeIntVal( Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true)); auto HostPostOutlineCB = [this, Ident, @@ -10029,7 +10179,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc, }; if (!Config.isTargetDevice()) - OI.PostOutlineCB = HostPostOutlineCB; + OI->PostOutlineCB = HostPostOutlineCB; addOutlineInfo(std::move(OI)); @@ -10068,11 +10218,10 @@ OpenMPIRBuilder::createDistribute(const LocationDescription &Loc, // When using target we use different runtime functions which require a // callback. if (Config.isTargetDevice()) { - OutlineInfo OI; - OI.OuterAllocaBB = OuterAllocaIP.getBlock(); - OI.EntryBB = AllocaBB; - OI.ExitBB = ExitBB; - + auto OI = std::make_unique(); + OI->OuterAllocaBB = OuterAllocaIP.getBlock(); + OI->EntryBB = AllocaBB; + OI->ExitBB = ExitBB; addOutlineInfo(std::move(OI)); } Builder.SetInsertPoint(ExitBB, ExitBB->begin()); @@ -10133,6 +10282,39 @@ void OpenMPIRBuilder::OutlineInfo::collectBlocks( } } +std::unique_ptr +OpenMPIRBuilder::OutlineInfo::createCodeExtractor(ArrayRef Blocks, + bool ArgsInZeroAddressSpace, + Twine Suffix) { + return std::make_unique(Blocks, /* DominatorTree */ nullptr, + /* AggregateArgs */ true, + /* BlockFrequencyInfo */ nullptr, + /* BranchProbabilityInfo */ nullptr, + /* AssumptionCache */ nullptr, + /* AllowVarArgs */ true, + /* AllowAlloca */ true, + /* AllocationBlock*/ OuterAllocaBB, + /* DeallocationBlock */ nullptr, + /* Suffix */ Suffix.str(), + ArgsInZeroAddressSpace); +} + +std::unique_ptr DeviceSharedMemOutlineInfo::createCodeExtractor( + ArrayRef Blocks, bool ArgsInZeroAddressSpace, Twine Suffix) { + // TODO: Initialize the DeallocationBlock with a proper pair to OuterAllocaBB. + return std::make_unique( + OMPBuilder, AllocBlockOverride, Blocks, /* DominatorTree */ nullptr, + /* AggregateArgs */ true, + /* BlockFrequencyInfo */ nullptr, + /* BranchProbabilityInfo */ nullptr, + /* AssumptionCache */ nullptr, + /* AllowVarArgs */ true, + /* AllowAlloca */ true, + /* AllocationBlock*/ OuterAllocaBB, + /* DeallocationBlock */ ExitBB, + /* Suffix */ Suffix.str(), ArgsInZeroAddressSpace); +} + void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr, uint64_t Size, int32_t Flags, GlobalValue::LinkageTypes, diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp index 3d8b7cbb59630..57809017a75a4 100644 --- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp +++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp @@ -721,6 +721,7 @@ bool HotColdSplitting::outlineColdRegions(Function &F, bool HasProfileSummary) { SubRegion, &*DT, /* AggregateArgs */ false, /* BFI */ nullptr, /* BPI */ nullptr, AC, /* AllowVarArgs */ false, /* AllowAlloca */ false, /* AllocaBlock */ nullptr, + /* DeallocationBlock */ nullptr, /* Suffix */ "cold." + std::to_string(OutlinedFunctionID)); if (CE.isEligible() && isSplittingBeneficial(CE, SubRegion, TTI) && diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp index fdf0c3ac8007d..177c10ef53040 100644 --- a/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -2826,7 +2826,7 @@ unsigned IROutliner::doOutline(Module &M) { OS->Candidate->getBasicBlocks(BlocksInRegion, BE); OS->CE = new (ExtractorAllocator.Allocate()) CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, - false, nullptr, "outlined"); + false, nullptr, nullptr, "outlined"); findAddInputsOutputs(M, *OS, NotSame); if (!OS->IgnoreRegion) OutlinedRegions.push_back(OS); @@ -2937,7 +2937,7 @@ unsigned IROutliner::doOutline(Module &M) { OS->Candidate->getBasicBlocks(BlocksInRegion, BE); OS->CE = new (ExtractorAllocator.Allocate()) CodeExtractor(BE, nullptr, false, nullptr, nullptr, nullptr, false, - false, nullptr, "outlined"); + false, nullptr, nullptr, "outlined"); bool FunctionOutlined = extractSection(*OS); if (FunctionOutlined) { unsigned StartIdx = OS->Candidate->getStartIdx(); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index bbd1ed6a3ab2d..3339f5e4fea7d 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -25,7 +25,6 @@ #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" -#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -264,11 +263,12 @@ CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, AssumptionCache *AC, bool AllowVarArgs, bool AllowAlloca, - BasicBlock *AllocationBlock, std::string Suffix, + BasicBlock *AllocationBlock, + BasicBlock *DeallocationBlock, std::string Suffix, bool ArgsInZeroAddressSpace) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AC(AC), AllocationBlock(AllocationBlock), - AllowVarArgs(AllowVarArgs), + DeallocationBlock(DeallocationBlock), AllowVarArgs(AllowVarArgs), Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {} @@ -444,6 +444,27 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { return CommonExitBlock; } +Instruction *CodeExtractor::allocateVar(BasicBlock *BB, + BasicBlock::iterator AllocIP, + Type *VarType, const Twine &Name, + AddrSpaceCastInst **CastedAlloc) { + const DataLayout &DL = BB->getModule()->getDataLayout(); + Instruction *Alloca = + new AllocaInst(VarType, DL.getAllocaAddrSpace(), nullptr, Name, AllocIP); + + if (CastedAlloc && ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) { + *CastedAlloc = new AddrSpaceCastInst( + Alloca, PointerType::get(BB->getContext(), 0), Name + ".ascast"); + (*CastedAlloc)->insertAfter(Alloca->getIterator()); + } + return Alloca; +} + +Instruction *CodeExtractor::deallocateVar(BasicBlock *, BasicBlock::iterator, + Value *, Type *) { + return nullptr; +} + // Find the pair of life time markers for address 'Addr' that are either // defined inside the outline region or can legally be shrinkwrapped into the // outline region. If there are not other untracked uses of the address, return @@ -1819,7 +1840,6 @@ CallInst *CodeExtractor::emitReplacerCall( std::vector &Reloads) { LLVMContext &Context = oldFunction->getContext(); Module *M = oldFunction->getParent(); - const DataLayout &DL = M->getDataLayout(); // This takes place of the original loop BasicBlock *codeReplacer = @@ -1850,25 +1870,22 @@ CallInst *CodeExtractor::emitReplacerCall( if (StructValues.contains(output)) continue; - AllocaInst *alloca = new AllocaInst( - output->getType(), DL.getAllocaAddrSpace(), nullptr, - output->getName() + ".loc", AllocaBlock->getFirstInsertionPt()); - params.push_back(alloca); - ReloadOutputs.push_back(alloca); + Value *OutAlloc = + allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(), + output->getType(), output->getName() + ".loc"); + params.push_back(OutAlloc); + ReloadOutputs.push_back(OutAlloc); } - AllocaInst *Struct = nullptr; + Instruction *Struct = nullptr; if (!StructValues.empty()) { - Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr, - "structArg", AllocaBlock->getFirstInsertionPt()); - if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) { - auto *StructSpaceCast = new AddrSpaceCastInst( - Struct, PointerType ::get(Context, 0), "structArg.ascast"); - StructSpaceCast->insertAfter(Struct->getIterator()); + AddrSpaceCastInst *StructSpaceCast = nullptr; + Struct = allocateVar(AllocaBlock, AllocaBlock->getFirstInsertionPt(), + StructArgTy, "structArg", &StructSpaceCast); + if (StructSpaceCast) params.push_back(StructSpaceCast); - } else { + else params.push_back(Struct); - } unsigned AggIdx = 0; for (Value *input : inputs) { @@ -2011,6 +2028,24 @@ CallInst *CodeExtractor::emitReplacerCall( insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart, {}, call); + // Deallocate intermediate variables if they need explicit deallocation. + BasicBlock *DeallocBlock = codeReplacer; + BasicBlock::iterator DeallocIP = codeReplacer->end(); + if (DeallocationBlock) { + DeallocBlock = DeallocationBlock; + DeallocIP = DeallocationBlock->getFirstInsertionPt(); + } + + int Index = 0; + for (Value *Output : outputs) { + if (!StructValues.contains(Output)) + deallocateVar(DeallocBlock, DeallocIP, ReloadOutputs[Index++], + Output->getType()); + } + + if (Struct) + deallocateVar(DeallocBlock, DeallocIP, Struct, StructArgTy); + return call; } diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp index 9ea8de3da1e5b..6fd266a815dcf 100644 --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -711,7 +711,8 @@ TEST(CodeExtractor, OpenMPAggregateArgs) { /* AssumptionCache */ nullptr, /* AllowVarArgs */ true, /* AllowAlloca */ true, - /* AllocaBlock*/ &Func->getEntryBlock(), + /* AllocationBlock*/ &Func->getEntryBlock(), + /* DeallocationBlock */ nullptr, /* Suffix */ ".outlined", /* ArgsInZeroAddressSpace */ true); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 0fa47d3f48a83..61ddc8339b692 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5953,6 +5953,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, static LogicalResult convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, + llvm::OpenMPIRBuilder *ompBuilder, LLVM::ModuleTranslation &moduleTranslation) { // Amend omp.declare_target by deleting the IR of the outlined functions // created for target regions. They cannot be filtered out from MLIR earlier @@ -5975,6 +5976,11 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute, moduleTranslation.lookupFunction(funcOp.getName()); llvmFunc->dropAllReferences(); llvmFunc->eraseFromParent(); + + // Invalidate the builder's current insertion point, as it now points to + // a deleted block. + ompBuilder->Builder.ClearInsertionPoint(); + ompBuilder->Builder.SetCurrentDebugLocation(llvm::DebugLoc()); } } return success(); @@ -6512,9 +6518,12 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation( .Case("omp.declare_target", [&](Attribute attr) { if (auto declareTargetAttr = - dyn_cast(attr)) + dyn_cast(attr)) { + llvm::OpenMPIRBuilder *ompBuilder = + moduleTranslation.getOpenMPBuilder(); return convertDeclareTargetAttr(op, declareTargetAttr, - moduleTranslation); + ompBuilder, moduleTranslation); + } return failure(); }) .Case("omp.requires", diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir index 60c6fa4dd8f1e..504e39c96f008 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir @@ -56,8 +56,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK-SAME: ptr %[[TMP:.*]], ptr %[[TMP0:.*]]) #{{[0-9]+}} { // CHECK: %[[TMP1:.*]] = alloca [1 x ptr], align 8, addrspace(5) // CHECK: %[[TMP2:.*]] = addrspacecast ptr addrspace(5) %[[TMP1]] to ptr -// CHECK: %[[STRUCTARG:.*]] = alloca { ptr }, align 8, addrspace(5) -// CHECK: %[[STRUCTARG_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[STRUCTARG]] to ptr // CHECK: %[[TMP3:.*]] = alloca ptr, align 8, addrspace(5) // CHECK: %[[TMP4:.*]] = addrspacecast ptr addrspace(5) %[[TMP3]] to ptr // CHECK: store ptr %[[TMP0]], ptr %[[TMP4]], align 8 @@ -65,12 +63,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: %[[EXEC_USER_CODE:.*]] = icmp eq i32 %[[TMP5]], -1 // CHECK: br i1 %[[EXEC_USER_CODE]], label %[[USER_CODE_ENTRY:.*]], label %[[WORKER_EXIT:.*]] // CHECK: %[[TMP6:.*]] = load ptr, ptr %[[TMP4]], align 8 +// CHECK: %[[STRUCTARG:.*]] = call align 8 ptr @__kmpc_alloc_shared(i64 8) // CHECK: %[[OMP_GLOBAL_THREAD_NUM:.*]] = call i32 @__kmpc_global_thread_num(ptr addrspacecast (ptr addrspace(1) @[[GLOB1:[0-9]+]] to ptr)) -// CHECK: %[[GEP_:.*]] = getelementptr { ptr }, ptr addrspace(5) %[[STRUCTARG]], i32 0, i32 0 -// CHECK: store ptr %[[TMP6]], ptr addrspace(5) %[[GEP_]], align 8 +// CHECK: %[[GEP_:.*]] = getelementptr { ptr }, ptr %[[STRUCTARG]], i32 0, i32 0 +// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8 // CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0 -// CHECK: store ptr %[[STRUCTARG_ASCAST]], ptr %[[TMP7]], align 8 +// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8 // CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1) +// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8) // CHECK: call void @__kmpc_target_deinit() // CHECK: define internal void @[[FUNC1]](