diff --git a/llvm/include/llvm/Transforms/IPO/BlockExtractor.h b/llvm/include/llvm/Transforms/IPO/BlockExtractor.h index cf6b1666b4fc6..b8849b3d841fe 100644 --- a/llvm/include/llvm/Transforms/IPO/BlockExtractor.h +++ b/llvm/include/llvm/Transforms/IPO/BlockExtractor.h @@ -23,12 +23,13 @@ class BasicBlock; struct BlockExtractorPass : PassInfoMixin { BlockExtractorPass(std::vector> &&GroupsOfBlocks, - bool EraseFunctions); + bool EraseFunctions, bool KeepOldBlocks); PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); private: std::vector> GroupsOfBlocks; bool EraseFunctions; + bool KeepOldBlocks; }; } // namespace llvm diff --git a/llvm/include/llvm/Transforms/Utils/Cloning.h b/llvm/include/llvm/Transforms/Utils/Cloning.h index 1e8ef0102450e..6b6bead270d72 100644 --- a/llvm/include/llvm/Transforms/Utils/Cloning.h +++ b/llvm/include/llvm/Transforms/Utils/Cloning.h @@ -117,10 +117,16 @@ struct ClonedCodeInfo { /// If you would like to collect additional information about the cloned /// function, you can specify a ClonedCodeInfo object with the optional fifth /// parameter. -BasicBlock *CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, - const Twine &NameSuffix = "", Function *F = nullptr, - ClonedCodeInfo *CodeInfo = nullptr, - DebugInfoFinder *DIFinder = nullptr); +/// +/// If you would like to clone only a subset of instructions in the basic block, +/// you can specify a callback that returns true only for those instructions +/// that are to be cloned with the optional seventh parameter. +BasicBlock * +CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, + const Twine &NameSuffix = "", Function *F = nullptr, + ClonedCodeInfo *CodeInfo = nullptr, + DebugInfoFinder *DIFinder = nullptr, + function_ref InstSelect = {}); /// Return a copy of the specified function and add it to that /// function's module.
Also, any references specified in the VMap are changed diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h index 826347e79f719..e40d2244fb4b5 100644 --- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h @@ -35,6 +35,7 @@ class Instruction; class Module; class Type; class Value; +class StructType; /// A cache for the CodeExtractor analysis. The operation \ref /// CodeExtractor::extractCodeRegion is guaranteed not to invalidate this @@ -99,14 +100,19 @@ class CodeExtractorAnalysisCache { // If true, varargs functions can be extracted. bool AllowVarArgs; + /// If true, copies the code into the extracted function instead of moving + /// it. + bool KeepOldBlocks; + // Bits of intermediate state computed at various phases of extraction. SetVector Blocks; - unsigned NumExitBlocks = std::numeric_limits::max(); - Type *RetTy; - // Mapping from the original exit blocks, to the new blocks inside - // the function. - SmallVector OldTargets; + /// List of blocks that are branched to from the code region to be extracted. + /// Each block is contained at most once. Its order defines the return value + /// of the extracted function, when leaving the extracted function via the + /// first block it returns 0. When leaving via the second entry it returns + /// 1, etc. + SmallVector SwitchCases; // Suffix to use when creating extracted function (appended to the original // function name + "."). If empty, the default is to use the entry block @@ -134,13 +140,19 @@ class CodeExtractorAnalysisCache { /// If ArgsInZeroAddressSpace param is set to true, then the aggregate /// param pointer of the outlined function is declared in zero address /// space. + /// + /// If KeepOldBlocks is true, the original instances of the extracted region + /// remain in the original function so they can still be branched to from + /// non-extracted blocks.
However, only branches to the first block will + /// call the extracted function. CodeExtractor(ArrayRef BBs, DominatorTree *DT = nullptr, bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr, BranchProbabilityInfo *BPI = nullptr, AssumptionCache *AC = nullptr, bool AllowVarArgs = false, bool AllowAlloca = false, BasicBlock *AllocationBlock = nullptr, - std::string Suffix = "", bool ArgsInZeroAddressSpace = false); + std::string Suffix = "", bool ArgsInZeroAddressSpace = false, + bool KeepOldBlocks = false); /// Perform the extraction, returning the new function. /// @@ -238,26 +250,61 @@ class CodeExtractorAnalysisCache { getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC, Instruction *Addr, BasicBlock *ExitBlock) const; + /// Updates the list of SwitchCases (corresponding to exit blocks) after + /// changes of the control flow or the Blocks list. + void recomputeSwitchCases(); + + /// Return the type used for the return code of the extracted function to + /// indicate which exit block to jump to. + Type *getSwitchType(); + void severSplitPHINodesOfEntry(BasicBlock *&Header); - void severSplitPHINodesOfExits(const SetVector &Exits); + void severSplitPHINodesOfExits(); void splitReturnBlocks(); - Function *constructFunction(const ValueSet &inputs, - const ValueSet &outputs, - BasicBlock *header, - BasicBlock *newRootNode, BasicBlock *newHeader, - Function *oldFunction, Module *M); - void moveCodeToFunction(Function *newFunction); void calculateNewCallTerminatorWeights( BasicBlock *CodeReplacer, - DenseMap &ExitWeights, + const DenseMap &ExitWeights, BranchProbabilityInfo *BPI); - CallInst *emitCallAndSwitchStatement(Function *newFunction, - BasicBlock *newHeader, - ValueSet &inputs, ValueSet &outputs); + /// Normalizes the control flow of the extracted regions, such as ensuring + /// that the extracted region does not contain a return instruction. 
+ void normalizeCFGForExtraction(BasicBlock *&header); + + /// Generates the function declaration for the function containing the + /// extracted code. + Function *constructFunctionDeclaration(const ValueSet &inputs, + const ValueSet &outputs, + BlockFrequency EntryFreq, + const Twine &Name, + ValueSet &StructValues, + StructType *&StructTy); + + /// Generates the code for the extracted function. That is: a prolog, the + /// moved or copied code from the original function, and epilogs for each + /// exit. + void emitFunctionBody(const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, Function *newFunction, + StructType *StructArgTy, BasicBlock *header, + const ValueSet &SinkingCands); + + /// Generates a Basic Block that calls the extracted function. + CallInst *emitReplacerCall(const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, + Function *newFunction, StructType *StructArgTy, + Function *oldFunction, BasicBlock *ReplIP, + BlockFrequency EntryFreq, + ArrayRef LifetimesStart, + std::vector &Reloads); + + /// Connects the basic block containing the call to the extracted function + /// into the original function's control flow. 
+ void insertReplacerCall( + Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer, + const ValueSet &outputs, ArrayRef Reloads, + const DenseMap &ExitWeights); }; } // end namespace llvm diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index 017ae311c55eb..873a789b58727 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -65,7 +65,7 @@ MODULE_PASS("dfsan", DataFlowSanitizerPass()) MODULE_PASS("dot-callgraph", CallGraphDOTPrinterPass()) MODULE_PASS("dxil-upgrade", DXILUpgradePass()) MODULE_PASS("elim-avail-extern", EliminateAvailableExternallyPass()) -MODULE_PASS("extract-blocks", BlockExtractorPass({}, false)) +MODULE_PASS("extract-blocks", BlockExtractorPass({}, false, false)) MODULE_PASS("expand-variadics", ExpandVariadicsPass(ExpandVariadicsMode::Disable)) MODULE_PASS("forceattrs", ForceFunctionAttrsPass()) MODULE_PASS("function-import", FunctionImportPass()) diff --git a/llvm/lib/Transforms/IPO/BlockExtractor.cpp b/llvm/lib/Transforms/IPO/BlockExtractor.cpp index ec1be35a33164..e96155f353f8e 100644 --- a/llvm/lib/Transforms/IPO/BlockExtractor.cpp +++ b/llvm/lib/Transforms/IPO/BlockExtractor.cpp @@ -41,7 +41,8 @@ static cl::opt namespace { class BlockExtractor { public: - BlockExtractor(bool EraseFunctions) : EraseFunctions(EraseFunctions) {} + BlockExtractor(bool EraseFunctions, bool KeepOldBlocks) + : EraseFunctions(EraseFunctions), KeepOldBlocks(KeepOldBlocks) {} bool runOnModule(Module &M); void init(const std::vector> &GroupsOfBlocksToExtract) { @@ -53,6 +54,7 @@ class BlockExtractor { private: std::vector> GroupsOfBlocks; bool EraseFunctions; + bool KeepOldBlocks; /// Map a function name to groups of blocks. 
SmallVector>, 4> BlocksByName; @@ -169,7 +171,19 @@ bool BlockExtractor::runOnModule(Module &M) { Changed = true; } CodeExtractorAnalysisCache CEAC(*BBs[0]->getParent()); - Function *F = CodeExtractor(BlocksToExtractVec).extractCodeRegion(CEAC); + Function *F = CodeExtractor(BlocksToExtractVec, + /* DT */ nullptr, + /* AggregateArgs*/ false, + /* BFI */ nullptr, + /* BPI */ nullptr, + /* AC */ nullptr, + /* AllowVarArgs */ false, + /* AllowAlloca */ false, + /* AllocationBlock */ nullptr, + /* Suffix */ "", + /* ArgsInZeroAddressSpace */ false, + /* KeepOldBlocks */ KeepOldBlocks) + .extractCodeRegion(CEAC); if (F) LLVM_DEBUG(dbgs() << "Extracted group '" << (*BBs.begin())->getName() << "' in: " << F->getName() << '\n'); @@ -196,12 +210,13 @@ bool BlockExtractor::runOnModule(Module &M) { BlockExtractorPass::BlockExtractorPass( std::vector> &&GroupsOfBlocks, - bool EraseFunctions) - : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions) {} + bool EraseFunctions, bool KeepOldBlocks) + : GroupsOfBlocks(GroupsOfBlocks), EraseFunctions(EraseFunctions), + KeepOldBlocks(KeepOldBlocks) {} PreservedAnalyses BlockExtractorPass::run(Module &M, ModuleAnalysisManager &AM) { - BlockExtractor BE(EraseFunctions); + BlockExtractor BE(EraseFunctions, KeepOldBlocks); BE.init(GroupsOfBlocks); return BE.runOnModule(M) ? PreservedAnalyses::none() : PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Utils/CloneFunction.cpp b/llvm/lib/Transforms/Utils/CloneFunction.cpp index a2d38717f38d1..a7aa7dd44f548 100644 --- a/llvm/lib/Transforms/Utils/CloneFunction.cpp +++ b/llvm/lib/Transforms/Utils/CloneFunction.cpp @@ -41,10 +41,11 @@ using namespace llvm; #define DEBUG_TYPE "clone-function" /// See comments in Cloning.h. 
-BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, - const Twine &NameSuffix, Function *F, - ClonedCodeInfo *CodeInfo, - DebugInfoFinder *DIFinder) { +BasicBlock * +llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, + const Twine &NameSuffix, Function *F, + ClonedCodeInfo *CodeInfo, DebugInfoFinder *DIFinder, + function_ref InstSelect) { BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "", F); NewBB->IsNewDbgInfoFormat = BB->IsNewDbgInfoFormat; if (BB->hasName()) @@ -55,6 +56,9 @@ BasicBlock *llvm::CloneBasicBlock(const BasicBlock *BB, ValueToValueMapTy &VMap, // Loop over all instructions, and copy them over. for (const Instruction &I : *BB) { + if (InstSelect && !InstSelect(&I)) + continue; + if (DIFinder && TheModule) DIFinder->processInstruction(*TheModule, I); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index fa467cc72bd02..b6906ab1f2207 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -59,6 +59,8 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" #include #include #include @@ -193,7 +195,8 @@ static bool isBlockValidForExtraction(const BasicBlock &BB, /// Build a set of blocks to extract if the input blocks are viable. static SetVector buildExtractionBlockSet(ArrayRef BBs, DominatorTree *DT, - bool AllowVarArgs, bool AllowAlloca) { + bool AllowVarArgs, bool AllowAlloca, + bool KeepOldBlocks) { assert(!BBs.empty() && "The set of blocks to extract must be non-empty"); SetVector Result; @@ -225,16 +228,20 @@ buildExtractionBlockSet(ArrayRef BBs, DominatorTree *DT, } // All blocks other than the first must not have predecessors outside of - // the subgraph which is being extracted. 
- for (auto *PBB : predecessors(BB)) - if (!Result.count(PBB)) { - LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from " - "outside the region except for the first block!\n" - << "Problematic source BB: " << BB->getName() << "\n" - << "Problematic destination BB: " << PBB->getName() - << "\n"); - return {}; - } + // the subgraph which is being extracted. KeepOldBlocks relaxes this + // requirement. + if (!KeepOldBlocks) { + for (auto *PBB : predecessors(BB)) + if (!Result.count(PBB)) { + LLVM_DEBUG(dbgs() + << "No blocks in this region may have entries from " + "outside the region except for the first block!\n" + << "Problematic source BB: " << BB->getName() << "\n" + << "Problematic destination BB: " << PBB->getName() + << "\n"); + return {}; + } + } } return Result; @@ -245,11 +252,12 @@ CodeExtractor::CodeExtractor(ArrayRef BBs, DominatorTree *DT, BranchProbabilityInfo *BPI, AssumptionCache *AC, bool AllowVarArgs, bool AllowAlloca, BasicBlock *AllocationBlock, std::string Suffix, - bool ArgsInZeroAddressSpace) + bool ArgsInZeroAddressSpace, bool KeepOldBlocks) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AC(AC), AllocationBlock(AllocationBlock), - AllowVarArgs(AllowVarArgs), - Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), + AllowVarArgs(AllowVarArgs), KeepOldBlocks(KeepOldBlocks), + Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca, + KeepOldBlocks)), Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {} /// definedInRegion - Return true if the specified value is defined in the @@ -421,7 +429,6 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { } // Now add the old exit block to the outline region. 
Blocks.insert(CommonExitBlock); - OldTargets.push_back(NewExitBlock); return CommonExitBlock; } @@ -638,6 +645,10 @@ void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs, // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. for (Instruction &II : *BB) { + // Ignore assumptions if they have not been removed yet. + if (isa(II)) + continue; + for (auto &OI : II.operands()) { Value *V = OI; if (!SinkCands.count(V) && @@ -735,9 +746,8 @@ void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) { /// outlined region, we split these PHIs on two: one with inputs from region /// and other with remaining incoming blocks; then first PHIs are placed in /// outlined region. -void CodeExtractor::severSplitPHINodesOfExits( - const SetVector &Exits) { - for (BasicBlock *ExitBB : Exits) { +void CodeExtractor::severSplitPHINodesOfExits() { + for (BasicBlock *ExitBB : SwitchCases) { BasicBlock *NewBB = nullptr; for (PHINode &PN : ExitBB->phis()) { @@ -801,44 +811,28 @@ void CodeExtractor::splitReturnBlocks() { } } -/// constructFunction - make a function based on inputs and outputs, as follows: -/// f(in0, ..., inN, out0, ..., outN) -Function *CodeExtractor::constructFunction(const ValueSet &inputs, - const ValueSet &outputs, - BasicBlock *header, - BasicBlock *newRootNode, - BasicBlock *newHeader, - Function *oldFunction, - Module *M) { +Function *CodeExtractor::constructFunctionDeclaration( + const ValueSet &inputs, const ValueSet &outputs, BlockFrequency EntryFreq, + const Twine &Name, ValueSet &StructValues, StructType *&StructTy) { LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); - // This function returns unsigned, outputs will go back by reference.
- switch (NumExitBlocks) { - case 0: - case 1: RetTy = Type::getVoidTy(header->getContext()); break; - case 2: RetTy = Type::getInt1Ty(header->getContext()); break; - default: RetTy = Type::getInt16Ty(header->getContext()); break; - } + Function *oldFunction = Blocks.front()->getParent(); + Module *M = Blocks.front()->getModule(); + // Assemble the function's parameter lists. std::vector ParamTy; std::vector AggParamTy; - std::vector> NumberedInputs; - std::vector> NumberedOutputs; - ValueSet StructValues; const DataLayout &DL = M->getDataLayout(); // Add the types of the input values to the function's argument list - unsigned ArgNum = 0; for (Value *value : inputs) { LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n"); if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) { AggParamTy.push_back(value->getType()); StructValues.insert(value); - } else { + } else ParamTy.push_back(value->getType()); - NumberedInputs.emplace_back(ArgNum++, value); - } } // Add the types of the output values to the function's argument list. @@ -847,11 +841,9 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) { AggParamTy.push_back(output->getType()); StructValues.insert(output); - } else { + } else ParamTy.push_back( PointerType::get(output->getType(), DL.getAllocaAddrSpace())); - NumberedOutputs.emplace_back(ArgNum++, output); - } } assert( @@ -862,14 +854,13 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, "Expeced StructValues only with AggregateArgs set"); // Concatenate scalar and aggregate params in ParamTy. - size_t NumScalarParams = ParamTy.size(); - StructType *StructTy = nullptr; - if (AggregateArgs && !AggParamTy.empty()) { + if (!AggParamTy.empty()) { StructTy = StructType::get(M->getContext(), AggParamTy); ParamTy.push_back(PointerType::get( StructTy, ArgsInZeroAddressSpace ? 
0 : DL.getAllocaAddrSpace())); } + Type *RetTy = getSwitchType(); LLVM_DEBUG({ dbgs() << "Function type: " << *RetTy << " f("; for (Type *i : ParamTy) @@ -880,15 +871,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, FunctionType *funcType = FunctionType::get( RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg()); - std::string SuffixToUse = - Suffix.empty() - ? (header->getName().empty() ? "extracted" : header->getName().str()) - : Suffix; // Create the new function - Function *newFunction = Function::Create( - funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(), - oldFunction->getName() + "." + SuffixToUse, M); - newFunction->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; + Function *newFunction = + Function::Create(funcType, GlobalValue::InternalLinkage, + oldFunction->getAddressSpace(), Name, M); + + // Propagate personality info to the new function if there is one. + if (oldFunction->hasPersonalityFn()) + newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); // Inherit all of the target dependent attributes and white-listed // target independent attributes. @@ -1017,69 +1007,56 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, newFunction->addFnAttr(Attr); } - if (NumExitBlocks == 0) { - // Mark the new function `noreturn` if applicable. Terminators which resume - // exception propagation are treated as returning instructions. This is to - // avoid inserting traps after calls to outlined functions which unwind. - if (none_of(Blocks, [](const BasicBlock *BB) { - const Instruction *Term = BB->getTerminator(); - return isa(Term) || isa(Term); - })) - newFunction->setDoesNotReturn(); - } - - newFunction->insert(newFunction->end(), newRootNode); - // Create scalar and aggregate iterators to name all of the arguments we // inserted. 
Function::arg_iterator ScalarAI = newFunction->arg_begin(); - Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams); - // Rewrite all users of the inputs in the extracted region to use the - // arguments (or appropriate addressing into struct) instead. - for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) { - Value *RewriteVal; - if (AggregateArgs && StructValues.contains(inputs[i])) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); - Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx); - BasicBlock::iterator TI = newFunction->begin()->getTerminator()->getIterator(); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI); - RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP, - "loadgep_" + inputs[i]->getName(), TI); - ++aggIdx; - } else - RewriteVal = &*ScalarAI++; + // Set names and attributes for input and output arguments. + ScalarAI = newFunction->arg_begin(); + for (Value *input : inputs) { + if (StructValues.contains(input)) + continue; - std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); - for (User *use : Users) - if (Instruction *inst = dyn_cast(use)) - if (Blocks.count(inst->getParent())) - inst->replaceUsesOfWith(inputs[i], RewriteVal); + ScalarAI->setName(input->getName()); + if (input->isSwiftError()) + newFunction->addParamAttr(ScalarAI - newFunction->arg_begin(), + Attribute::SwiftError); + ++ScalarAI; } + for (Value *output : outputs) { + if (StructValues.contains(output)) + continue; - // Set names for input and output arguments. 
- for (auto [i, argVal] : NumberedInputs) - newFunction->getArg(i)->setName(argVal->getName()); - for (auto [i, argVal] : NumberedOutputs) - newFunction->getArg(i)->setName(argVal->getName() + ".out"); + ScalarAI->setName(output->getName() + ".out"); + ++ScalarAI; + } - // Rewrite branches to basic blocks outside of the loop to new dummy blocks - // within the new function. This must be done before we lose track of which - // blocks were originally in the code region. - std::vector Users(header->user_begin(), header->user_end()); - for (auto &U : Users) - // The BasicBlock which contains the branch is not in the region - // modify the branch target to a new block - if (Instruction *I = dyn_cast(U)) - if (I->isTerminator() && I->getFunction() == oldFunction && - !Blocks.count(I->getParent())) - I->replaceUsesOfWith(header, newHeader); + // Update the entry count of the function. + if (BFI) { + auto Count = BFI->getProfileCountFromFreq(EntryFreq); + if (Count.has_value()) + newFunction->setEntryCount( + ProfileCount(*Count, Function::PCT_Real)); // FIXME + } return newFunction; } +static void applyFirstDebugLoc(Function *oldFunction, + ArrayRef Blocks, + Instruction *BranchI) { + if (oldFunction->getSubprogram()) { + any_of(Blocks, [&BranchI](const BasicBlock *BB) { + return any_of(*BB, [&BranchI](const Instruction &I) { + if (!I.getDebugLoc()) + return false; + BranchI->setDebugLoc(I.getDebugLoc()); + return true; + }); + }); + } +} + /// Erase lifetime.start markers which reference inputs to the extraction /// region, and insert the referenced memory into \p LifetimesStart. /// @@ -1148,410 +1125,103 @@ static void insertLifetimeMarkersSurroundingCall( } } -/// emitCallAndSwitchStatement - This method sets up the caller side by adding -/// the call instruction, splitting any PHI nodes in the header block as -/// necessary. 
-CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, - BasicBlock *codeReplacer, - ValueSet &inputs, - ValueSet &outputs) { - // Emit a call to the new function, passing in: *pointer to struct (if - // aggregating parameters), or plan inputs and allocated memory for outputs - std::vector params, ReloadOutputs, Reloads; - ValueSet StructValues; - - Module *M = newFunction->getParent(); - LLVMContext &Context = M->getContext(); - const DataLayout &DL = M->getDataLayout(); - CallInst *call = nullptr; +void CodeExtractor::moveCodeToFunction(Function *newFunction) { + auto newFuncIt = newFunction->begin(); + for (BasicBlock *Block : Blocks) { + // Delete the basic block from the old function, and the list of blocks + Block->removeFromParent(); - // Add inputs as params, or to be filled into the struct - unsigned ScalarInputArgNo = 0; - SmallVector SwiftErrorArgs; - for (Value *input : inputs) { - if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input)) - StructValues.insert(input); - else { - params.push_back(input); - if (input->isSwiftError()) - SwiftErrorArgs.push_back(ScalarInputArgNo); - } - ++ScalarInputArgNo; + // Insert this basic block into the new function + // Insert the original blocks after the entry block created + // for the new function. The entry block may be followed + // by a set of exit blocks at this point, but these exit + // blocks better be placed at the end of the new function. 
+ newFuncIt = newFunction->insert(std::next(newFuncIt), Block); } +} - // Create allocas for the outputs - unsigned ScalarOutputArgNo = 0; - for (Value *output : outputs) { - if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) { - StructValues.insert(output); - } else { - AllocaInst *alloca = - new AllocaInst(output->getType(), DL.getAllocaAddrSpace(), - nullptr, output->getName() + ".loc", - codeReplacer->getParent()->front().begin()); - ReloadOutputs.push_back(alloca); - params.push_back(alloca); - ++ScalarOutputArgNo; - } - } +void CodeExtractor::calculateNewCallTerminatorWeights( + BasicBlock *CodeReplacer, + const DenseMap &ExitWeights, + BranchProbabilityInfo *BPI) { + using Distribution = BlockFrequencyInfoImplBase::Distribution; + using BlockNode = BlockFrequencyInfoImplBase::BlockNode; - StructType *StructArgTy = nullptr; - AllocaInst *Struct = nullptr; - unsigned NumAggregatedInputs = 0; - if (AggregateArgs && !StructValues.empty()) { - std::vector ArgTypes; - for (Value *V : StructValues) - ArgTypes.push_back(V->getType()); - - // Allocate a struct at the beginning of this function - StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); - Struct = new AllocaInst( - StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg", - AllocationBlock ? AllocationBlock->getFirstInsertionPt() - : codeReplacer->getParent()->front().begin()); + // Update the branch weights for the exit block. + Instruction *TI = CodeReplacer->getTerminator(); + SmallVector BranchWeights(TI->getNumSuccessors(), 0); - if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) { - auto *StructSpaceCast = new AddrSpaceCastInst( - Struct, PointerType ::get(Context, 0), "structArg.ascast"); - StructSpaceCast->insertAfter(Struct); - params.push_back(StructSpaceCast); - } else { - params.push_back(Struct); - } - // Store aggregated inputs in the struct. 
- for (unsigned i = 0, e = StructValues.size(); i != e; ++i) { - if (inputs.contains(StructValues[i])) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName()); - GEP->insertInto(codeReplacer, codeReplacer->end()); - new StoreInst(StructValues[i], GEP, codeReplacer); - NumAggregatedInputs++; - } - } - } + // Block Frequency distribution with dummy node. + Distribution BranchDist; - // Emit the call to the function - call = CallInst::Create(newFunction, params, - NumExitBlocks > 1 ? "targetBlock" : ""); - // Add debug location to the new call, if the original function has debug - // info. In that case, the terminator of the entry block of the extracted - // function contains the first debug location of the extracted function, - // set in extractCodeRegion. - if (codeReplacer->getParent()->getSubprogram()) { - if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc()) - call->setDebugLoc(DL); - } - call->insertInto(codeReplacer, codeReplacer->end()); + SmallVector EdgeProbabilities( + TI->getNumSuccessors(), BranchProbability::getUnknown()); - // Set swifterror parameter attributes. - for (unsigned SwiftErrArgNo : SwiftErrorArgs) { - call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); - newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError); + // Add each of the frequencies of the successors. + for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { + BlockNode ExitNode(i); + uint64_t ExitFreq = ExitWeights.lookup(TI->getSuccessor(i)).getFrequency(); + if (ExitFreq != 0) + BranchDist.addExit(ExitNode, ExitFreq); + else + EdgeProbabilities[i] = BranchProbability::getZero(); } - // Reload the outputs passed in by reference, use the struct if output is in - // the aggregate or reload from the scalar argument. 
- for (unsigned i = 0, e = outputs.size(), scalarIdx = 0, - aggIdx = NumAggregatedInputs; - i != e; ++i) { - Value *Output = nullptr; - if (AggregateArgs && StructValues.contains(outputs[i])) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); - GEP->insertInto(codeReplacer, codeReplacer->end()); - Output = GEP; - ++aggIdx; - } else { - Output = ReloadOutputs[scalarIdx]; - ++scalarIdx; - } - LoadInst *load = new LoadInst(outputs[i]->getType(), Output, - outputs[i]->getName() + ".reload", - codeReplacer); - Reloads.push_back(load); - std::vector Users(outputs[i]->user_begin(), outputs[i]->user_end()); - for (User *U : Users) { - Instruction *inst = cast(U); - if (!Blocks.count(inst->getParent())) - inst->replaceUsesOfWith(outputs[i], load); - } + // Check for no total weight. + if (BranchDist.Total == 0) { + BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities); + return; } - // Now we can emit a switch statement using the call as a value. - SwitchInst *TheSwitch = - SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), - codeReplacer, 0, codeReplacer); - - // Since there may be multiple exits from the original region, make the new - // function return an unsigned, switch on that number. This loop iterates - // over all of the blocks in the extracted region, updating any terminator - // instructions in the to-be-extracted region that branch to blocks that are - // not in the region to be extracted. - std::map ExitBlockMap; - - // Iterate over the previously collected targets, and create new blocks inside - // the function to branch to. 
- unsigned switchVal = 0; - for (BasicBlock *OldTarget : OldTargets) { - if (Blocks.count(OldTarget)) - continue; - BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; - if (NewTarget) - continue; - - // If we don't already have an exit stub for this non-extracted - // destination, create one now! - NewTarget = BasicBlock::Create(Context, - OldTarget->getName() + ".exitStub", - newFunction); - unsigned SuccNum = switchVal++; - - Value *brVal = nullptr; - assert(NumExitBlocks < 0xffff && "too many exit blocks for switch"); - switch (NumExitBlocks) { - case 0: - case 1: break; // No value needed. - case 2: // Conditional branch, return a bool - brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); - break; - default: - brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); - break; - } + // Normalize the distribution so that they can fit in unsigned. + BranchDist.normalize(); - ReturnInst::Create(Context, brVal, NewTarget); + // Create normalized branch weights and set the metadata. + for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { + const auto &Weight = BranchDist.Weights[I]; - // Update the switch instruction. - TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), - SuccNum), - OldTarget); + // Get the weight and update the current BFI. 
+ BranchWeights[Weight.TargetNode.Index] = Weight.Amount; + BranchProbability BP(Weight.Amount, BranchDist.Total); + EdgeProbabilities[Weight.TargetNode.Index] = BP; } + BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities); + TI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); +} - for (BasicBlock *Block : Blocks) { - Instruction *TI = Block->getTerminator(); - for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { - if (Blocks.count(TI->getSuccessor(i))) - continue; - BasicBlock *OldTarget = TI->getSuccessor(i); - // add a new basic block which returns the appropriate value - BasicBlock *NewTarget = ExitBlockMap[OldTarget]; - assert(NewTarget && "Unknown target block!"); - - // rewrite the original branch instruction with this new target - TI->setSuccessor(i, NewTarget); - } +/// Erase debug info intrinsics which refer to values in \p F but aren't in +/// \p F. +static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) { + for (Instruction &I : instructions(F)) { + SmallVector DbgUsers; + SmallVector DbgVariableRecords; + findDbgUsers(DbgUsers, &I, &DbgVariableRecords); + for (DbgVariableIntrinsic *DVI : DbgUsers) + if (DVI->getFunction() != &F) + DVI->eraseFromParent(); + for (DbgVariableRecord *DVR : DbgVariableRecords) + if (DVR->getFunction() != &F) + DVR->eraseFromParent(); } +} - // Store the arguments right after the definition of output value. - // This should be proceeded after creating exit stubs to be ensure that invoke - // result restore will be placed in the outlined function. 
- Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin(); - std::advance(ScalarOutputArgBegin, ScalarInputArgNo); - Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin(); - std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo); - - for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e; - ++i) { - auto *OutI = dyn_cast(outputs[i]); - if (!OutI) - continue; +/// Fix up the debug info in the old and new functions by pointing line +/// locations and debug intrinsics to the new subprogram scope, and by deleting +/// intrinsics which point to values outside of the new function. +static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, + CallInst &TheCall) { + DISubprogram *OldSP = OldFunc.getSubprogram(); + LLVMContext &Ctx = OldFunc.getContext(); - // Find proper insertion point. - BasicBlock::iterator InsertPt; - // In case OutI is an invoke, we insert the store at the beginning in the - // 'normal destination' BB. Otherwise we insert the store right after OutI. 
- if (auto *InvokeI = dyn_cast(OutI)) - InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); - else if (auto *Phi = dyn_cast(OutI)) - InsertPt = Phi->getParent()->getFirstInsertionPt(); - else - InsertPt = std::next(OutI->getIterator()); - - assert((InsertPt->getFunction() == newFunction || - Blocks.count(InsertPt->getParent())) && - "InsertPt should be in new function"); - if (AggregateArgs && StructValues.contains(outputs[i])) { - assert(AggOutputArgBegin != newFunction->arg_end() && - "Number of aggregate output arguments should match " - "the number of defined values"); - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(), - InsertPt); - new StoreInst(outputs[i], GEP, InsertPt); - ++aggIdx; - // Since there should be only one struct argument aggregating - // all the output values, we shouldn't increment AggOutputArgBegin, which - // always points to the struct argument, in this case. - } else { - assert(ScalarOutputArgBegin != newFunction->arg_end() && - "Number of scalar output arguments should match " - "the number of defined values"); - new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertPt); - ++ScalarOutputArgBegin; - } - } - - // Now that we've done the deed, simplify the switch instruction. - Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); - switch (NumExitBlocks) { - case 0: - // There are no successors (the block containing the switch itself), which - // means that previously this was the last part of the function, and hence - // this should be rewritten as a `ret` or `unreachable`. - if (newFunction->doesNotReturn()) { - // If fn is no return, end with an unreachable terminator. 
- (void)new UnreachableInst(Context, TheSwitch->getIterator()); - } else if (OldFnRetTy->isVoidTy()) { - // We have no return value. - ReturnInst::Create(Context, nullptr, - TheSwitch->getIterator()); // Return void - } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { - // return what we have - ReturnInst::Create(Context, TheSwitch->getCondition(), - TheSwitch->getIterator()); - } else { - // Otherwise we must have code extracted an unwind or something, just - // return whatever we want. - ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy), - TheSwitch->getIterator()); - } - - TheSwitch->eraseFromParent(); - break; - case 1: - // Only a single destination, change the switch into an unconditional - // branch. - BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getIterator()); - TheSwitch->eraseFromParent(); - break; - case 2: - BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), - call, TheSwitch->getIterator()); - TheSwitch->eraseFromParent(); - break; - default: - // Otherwise, make the default destination of the switch instruction be one - // of the other successors. - TheSwitch->setCondition(call); - TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); - // Remove redundant case - TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); - break; - } - - // Insert lifetime markers around the reloads of any output values. The - // allocas output values are stored in are only in-use in the codeRepl block. 
- insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call); - - return call; -} - -void CodeExtractor::moveCodeToFunction(Function *newFunction) { - auto newFuncIt = newFunction->front().getIterator(); - for (BasicBlock *Block : Blocks) { - // Delete the basic block from the old function, and the list of blocks - Block->removeFromParent(); - - // Insert this basic block into the new function - // Insert the original blocks after the entry block created - // for the new function. The entry block may be followed - // by a set of exit blocks at this point, but these exit - // blocks better be placed at the end of the new function. - newFuncIt = newFunction->insert(std::next(newFuncIt), Block); - } -} - -void CodeExtractor::calculateNewCallTerminatorWeights( - BasicBlock *CodeReplacer, - DenseMap &ExitWeights, - BranchProbabilityInfo *BPI) { - using Distribution = BlockFrequencyInfoImplBase::Distribution; - using BlockNode = BlockFrequencyInfoImplBase::BlockNode; - - // Update the branch weights for the exit block. - Instruction *TI = CodeReplacer->getTerminator(); - SmallVector BranchWeights(TI->getNumSuccessors(), 0); - - // Block Frequency distribution with dummy node. - Distribution BranchDist; - - SmallVector EdgeProbabilities( - TI->getNumSuccessors(), BranchProbability::getUnknown()); - - // Add each of the frequencies of the successors. - for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { - BlockNode ExitNode(i); - uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency(); - if (ExitFreq != 0) - BranchDist.addExit(ExitNode, ExitFreq); - else - EdgeProbabilities[i] = BranchProbability::getZero(); - } - - // Check for no total weight. - if (BranchDist.Total == 0) { - BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities); - return; - } - - // Normalize the distribution so that they can fit in unsigned. - BranchDist.normalize(); - - // Create normalized branch weights and set the metadata. 
- for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { - const auto &Weight = BranchDist.Weights[I]; - - // Get the weight and update the current BFI. - BranchWeights[Weight.TargetNode.Index] = Weight.Amount; - BranchProbability BP(Weight.Amount, BranchDist.Total); - EdgeProbabilities[Weight.TargetNode.Index] = BP; - } - BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities); - TI->setMetadata( - LLVMContext::MD_prof, - MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); -} - -/// Erase debug info intrinsics which refer to values in \p F but aren't in -/// \p F. -static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) { - for (Instruction &I : instructions(F)) { - SmallVector DbgUsers; - SmallVector DbgVariableRecords; - findDbgUsers(DbgUsers, &I, &DbgVariableRecords); - for (DbgVariableIntrinsic *DVI : DbgUsers) - if (DVI->getFunction() != &F) - DVI->eraseFromParent(); - for (DbgVariableRecord *DVR : DbgVariableRecords) - if (DVR->getFunction() != &F) - DVR->eraseFromParent(); - } -} - -/// Fix up the debug info in the old and new functions by pointing line -/// locations and debug intrinsics to the new subprogram scope, and by deleting -/// intrinsics which point to values outside of the new function. -static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc, - CallInst &TheCall) { - DISubprogram *OldSP = OldFunc.getSubprogram(); - LLVMContext &Ctx = OldFunc.getContext(); - - if (!OldSP) { - // Erase any debug info the new function contains. - stripDebugInfo(NewFunc); - // Make sure the old function doesn't contain any non-local metadata refs. - eraseDebugIntrinsicsWithNonLocalRefs(NewFunc); - return; - } + if (!OldSP) { + // Erase any debug info the new function contains. + stripDebugInfo(NewFunc); + // Make sure the old function doesn't contain any non-local metadata refs. + eraseDebugIntrinsicsWithNonLocalRefs(NewFunc); + return; + } // Create a subprogram for the new function. 
Leave out a description of the // function arguments, as the parameters don't correspond to anything at the @@ -1722,9 +1392,51 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, BasicBlock *header = *Blocks.begin(); Function *oldFunction = header->getParent(); + normalizeCFGForExtraction(header); + + if (!KeepOldBlocks) { + // Remove @llvm.assume calls that will be moved to the new function from the + // old function's assumption cache. + for (BasicBlock *Block : Blocks) { + for (Instruction &I : llvm::make_early_inc_range(*Block)) { + if (auto *AI = dyn_cast(&I)) { + if (AC) + AC->unregisterAssumption(AI); + AI->eraseFromParent(); + } + } + } + } + + ValueSet SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; + findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + assert(HoistingCands.empty() || CommonExit); + + // Find inputs to, outputs from the code region. + findInputsOutputs(inputs, outputs, SinkingCands); + + // Collect objects which are inputs to the extraction region and also + // referenced by lifetime start markers within it. The effects of these + // markers must be replicated in the calling function to prevent the stack + // coloring pass from merging slots which store input objects. + ValueSet LifetimesStart; + eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart); + + if (!HoistingCands.empty()) { + auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); + Instruction *TI = HoistToBlock->getTerminator(); + for (auto *II : HoistingCands) + cast(II)->moveBefore(TI); + recomputeSwitchCases(); + } + + // CFG/ExitBlocks must not change hereafter + // Calculate the entry frequency of the new function before we change the root // block. 
BlockFrequency EntryFreq; + DenseMap ExitWeights; if (BFI) { assert(BPI && "Both BPI and BFI are required to preserve profile info"); for (BasicBlock *Pred : predecessors(header)) { @@ -1733,196 +1445,728 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC, EntryFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); } - } - // Remove @llvm.assume calls that will be moved to the new function from the - // old function's assumption cache. - for (BasicBlock *Block : Blocks) { - for (Instruction &I : llvm::make_early_inc_range(*Block)) { - if (auto *AI = dyn_cast(&I)) { - if (AC) - AC->unregisterAssumption(AI); - AI->eraseFromParent(); + for (BasicBlock *Succ : SwitchCases) { + for (BasicBlock *Block : predecessors(Succ)) { + if (!Blocks.count(Block)) + continue; + + // Update the branch weight for this successor. + BlockFrequency &BF = ExitWeights[Succ]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ); } } } + // Determine position for the replacement code. Do so before header is moved + // to the new function. + BasicBlock *ReplIP = header; + if (!KeepOldBlocks) { + while (ReplIP && Blocks.count(ReplIP)) + ReplIP = ReplIP->getNextNode(); + } + + // Construct new function based on inputs/outputs & add allocas for all defs. + std::string SuffixToUse = + Suffix.empty() + ? (header->getName().empty() ? "extracted" : header->getName().str()) + : Suffix; + + ValueSet StructValues; + StructType *StructTy; + Function *newFunction = constructFunctionDeclaration( + inputs, outputs, EntryFreq, oldFunction->getName() + "." 
+ SuffixToUse, + StructValues, StructTy); + newFunction->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; + + emitFunctionBody(inputs, outputs, StructValues, newFunction, StructTy, header, + SinkingCands); + + std::vector Reloads; + CallInst *TheCall = emitReplacerCall( + inputs, outputs, StructValues, newFunction, StructTy, oldFunction, ReplIP, + EntryFreq, LifetimesStart.getArrayRef(), Reloads); + + insertReplacerCall(oldFunction, header, TheCall->getParent(), outputs, + Reloads, ExitWeights); + + fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall); + + LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) { + newFunction->dump(); + report_fatal_error("verification of newFunction failed!"); + }); + LLVM_DEBUG(if (verifyFunction(*oldFunction)) + report_fatal_error("verification of oldFunction failed!")); + LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC)) + report_fatal_error("Stale Asumption cache for old Function!")); + return newFunction; +} + +void CodeExtractor::normalizeCFGForExtraction(BasicBlock *&header) { // If we have any return instructions in the region, split those blocks so // that the return is not in the region. splitReturnBlocks(); - // Calculate the exit blocks for the extracted region and the total exit - // weights for each of those blocks. - DenseMap ExitWeights; - SetVector ExitBlocks; - for (BasicBlock *Block : Blocks) { - for (BasicBlock *Succ : successors(Block)) { - if (!Blocks.count(Succ)) { - // Update the branch weight for this successor. - if (BFI) { - BlockFrequency &BF = ExitWeights[Succ]; - BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ); + // If we have to split PHI nodes of the entry or exit blocks, do so now. 
+ severSplitPHINodesOfEntry(header); + + // If a PHI in an exit block has multiple invoming values from the outlined + // region, create a new PHI for those values within the region such that only + // PHI itself becomes an output value, not each of its incoming values + // individually. + recomputeSwitchCases(); + severSplitPHINodesOfExits(); + + // If the option was given, ensure there are no PHI nodes at all in the exit + // nodes themselves. + if (KeepOldBlocks) { + for (BasicBlock *Block : Blocks) { + for (BasicBlock *Succ : make_early_inc_range(successors(Block))) { + if (Blocks.count(Succ)) + continue; + + if (!Succ->getSinglePredecessor()) + Succ = SplitEdge(Block, Succ, DT); + + // Ensure no PHI node in exit block (still possible with single + // predecessor, e.g. LCSSA) + while (auto *P = dyn_cast(&Succ->front())) { + assert(P->getNumIncomingValues() == 1); + P->replaceAllUsesWith(P->getIncomingValue(0)); + P->eraseFromParent(); } - ExitBlocks.insert(Succ); } } + + // Exit nodes may have changed by SplitEdge. + recomputeSwitchCases(); } - NumExitBlocks = ExitBlocks.size(); +} + +void CodeExtractor::recomputeSwitchCases() { + SwitchCases.clear(); + SmallPtrSet ExitBlocks; for (BasicBlock *Block : Blocks) { - for (BasicBlock *OldTarget : successors(Block)) - if (!Blocks.contains(OldTarget)) - OldTargets.push_back(OldTarget); + for (BasicBlock *Succ : successors(Block)) { + if (Blocks.count(Succ)) + continue; + + bool IsNew = ExitBlocks.insert(Succ).second; + if (IsNew) + SwitchCases.push_back(Succ); + } } +} - // If we have to split PHI nodes of the entry or exit blocks, do so now. 
- severSplitPHINodesOfEntry(header); - severSplitPHINodesOfExits(ExitBlocks); +Type *CodeExtractor::getSwitchType() { + LLVMContext &Context = Blocks.front()->getContext(); - // This takes place of the original loop - BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), - "codeRepl", oldFunction, - header); - codeReplacer->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; + assert(SwitchCases.size() < 0xffff && "too many exit blocks for switch"); + switch (SwitchCases.size()) { + case 0: + case 1: + return Type::getVoidTy(Context); + case 2: + // Conditional branch, return a bool + return Type::getInt1Ty(Context); + default: + return Type::getInt16Ty(Context); + } +} + +void CodeExtractor::emitFunctionBody( + const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, Function *newFunction, + StructType *StructArgTy, BasicBlock *header, const ValueSet &SinkingCands) { + Function *oldFunction = header->getParent(); + LLVMContext &Context = oldFunction->getContext(); // The new function needs a root node because other nodes can branch to the // head of the region, but the entry node of a function cannot have preds. - BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), - "newFuncRoot"); + BasicBlock *newFuncRoot = + BasicBlock::Create(Context, "newFuncRoot", newFunction); newFuncRoot->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; - auto *BranchI = BranchInst::Create(header); - // If the original function has debug info, we have to add a debug location - // to the new branch instruction from the artificial entry block. - // We use the debug location of the first instruction in the extracted - // blocks, as there is no other equivalent line in the source code. 
- if (oldFunction->getSubprogram()) { - any_of(Blocks, [&BranchI](const BasicBlock *BB) { - return any_of(*BB, [&BranchI](const Instruction &I) { - if (!I.getDebugLoc()) - return false; - // Don't use source locations attached to debug-intrinsics: they could - // be from completely unrelated scopes. - if (isa(I)) - return false; - BranchI->setDebugLoc(I.getDebugLoc()); - return true; - }); - }); - } - BranchI->insertInto(newFuncRoot, newFuncRoot->end()); - - ValueSet SinkingCands, HoistingCands; - BasicBlock *CommonExit = nullptr; - findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); - assert(HoistingCands.empty() || CommonExit); + // The map of values from the original function to the corresponding values in + // the extracted function; only used with KeepOldBlocks. + ValueToValueMapTy VMap; + + // Additional instructions not in a extracted block whose operands need to be + // remapped. + SmallVector AdditionalRemap; + + // Copy or move (depending on KeepOldBlocks) an instruction to the new + // function. + auto MoveOrCopyInst = [this, newFuncRoot, &VMap, + &AdditionalRemap](Instruction *I) -> Instruction * { + BasicBlock::iterator IP = newFuncRoot->getFirstInsertionPt(); + if (!KeepOldBlocks) { + I->moveBefore(*newFuncRoot, IP); + return I; + } - // Find inputs to, outputs from the code region. - findInputsOutputs(inputs, outputs, SinkingCands); + Instruction *ClonedI = I->clone(); + ClonedI->setName(I->getName()); + ClonedI->insertInto(newFuncRoot, IP); + AdditionalRemap.push_back(ClonedI); + VMap[I] = ClonedI; + return ClonedI; + }; // Now sink all instructions which only have non-phi uses inside the region. // Group the allocas at the start of the block, so that any bitcast uses of // the allocas are well-defined. 
- AllocaInst *FirstSunkAlloca = nullptr; for (auto *II : SinkingCands) { - if (auto *AI = dyn_cast(II)) { - AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt()); - if (!FirstSunkAlloca) - FirstSunkAlloca = AI; + if (!isa(II)) { + MoveOrCopyInst(cast(II)); } } - assert((SinkingCands.empty() || FirstSunkAlloca) && - "Did not expect a sink candidate without any allocas"); for (auto *II : SinkingCands) { - if (!isa(II)) { - cast(II)->moveAfter(FirstSunkAlloca); + if (auto *AI = dyn_cast(II)) { + MoveOrCopyInst(AI); } } - if (!HoistingCands.empty()) { - auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit); - Instruction *TI = HoistToBlock->getTerminator(); - for (auto *II : HoistingCands) - cast(II)->moveBefore(TI); + Function::arg_iterator ScalarAI = newFunction->arg_begin(); + Argument *AggArg = StructValues.empty() + ? nullptr + : newFunction->getArg(newFunction->arg_size() - 1); + + // Rewrite all users of the inputs in the extracted region to use the + // arguments (or appropriate addressing into struct) instead. + SmallVector NewValues; + for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) { + Value *RewriteVal; + if (StructValues.contains(inputs[i])) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); + Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, AggArg, Idx, "gep_" + inputs[i]->getName(), newFuncRoot); + RewriteVal = new LoadInst(StructArgTy->getElementType(aggIdx), GEP, + "loadgep_" + inputs[i]->getName(), newFuncRoot); + ++aggIdx; + } else + RewriteVal = &*ScalarAI++; + + NewValues.push_back(RewriteVal); } - // Collect objects which are inputs to the extraction region and also - // referenced by lifetime start markers within it. 
The effects of these - // markers must be replicated in the calling function to prevent the stack - // coloring pass from merging slots which store input objects. - ValueSet LifetimesStart; - eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart); + if (KeepOldBlocks) { + // Copy blocks and instrutions to newFunction. + for (BasicBlock *Block : Blocks) { + BasicBlock *CBB = CloneBasicBlock( + Block, VMap, {}, newFunction, /* CodeInfo */ nullptr, + /* DIFinder */ nullptr, + [](const Instruction *I) -> bool { return !isa(I); }); + + // Add basic block mapping. + VMap[Block] = CBB; + + // It is only legal to clone a function if a block address within that + // function is never referenced outside of the function. Given that, we + // want to map block addresses from the old function to block addresses in + // the clone. (This is different from the generic ValueMapper + // implementation, which generates an invalid blockaddress when + // cloning a function.) + if (Block->hasAddressTaken()) { + Constant *OldBBAddr = BlockAddress::get(oldFunction, Block); + VMap[OldBBAddr] = BlockAddress::get(newFunction, CBB); + } - // Construct new function based on inputs/outputs & add allocas for all defs. - Function *newFunction = - constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer, - oldFunction, oldFunction->getParent()); + // Non-header block may have branches from outside the region. These + // continue to branch to the original blocks, hence remove their PHI + // entries. 
+ if (Block != header) + for (auto &&P : CBB->phis()) { + unsigned NumIncoming = P.getNumIncomingValues(); + for (int Idx = NumIncoming - 1; Idx >= 0; --Idx) { + BasicBlock *IncomingBlock = P.getIncomingBlock(Idx); + if (Blocks.count(IncomingBlock)) + continue; + P.removeIncomingValue(Idx, /*DeletePHIIfEmpty=*/false); + } + } + } + + for (auto P : enumerate(inputs)) + VMap[P.value()] = NewValues[P.index()]; + + } else { + moveCodeToFunction(newFunction); + + for (unsigned i = 0, e = inputs.size(); i != e; ++i) { + Value *RewriteVal = NewValues[i]; + + std::vector Users(inputs[i]->user_begin(), inputs[i]->user_end()); + for (User *use : Users) + if (Instruction *inst = dyn_cast(use)) + if (Blocks.count(inst->getParent())) + inst->replaceUsesOfWith(inputs[i], RewriteVal); + } + } + + // Since there may be multiple exits from the original region, make the new + // function return an unsigned, switch on that number. This loop iterates + // over all of the blocks in the extracted region, updating any terminator + // instructions in the to-be-extracted region that branch to blocks that are + // not in the region to be extracted. + std::map ExitBlockMap; + + // Iterate over the previously collected targets, and create new blocks inside + // the function to branch to. + for (auto P : enumerate(SwitchCases)) { + BasicBlock *OldTarget = P.value(); + size_t SuccNum = P.index(); + + BasicBlock *NewTarget = BasicBlock::Create( + Context, OldTarget->getName() + ".exitStub", newFunction); + ExitBlockMap[OldTarget] = NewTarget; + if (KeepOldBlocks) + VMap[OldTarget] = NewTarget; + + Value *brVal = nullptr; + Type *RetTy = getSwitchType(); + assert(SwitchCases.size() < 0xffff && "too many exit blocks for switch"); + switch (SwitchCases.size()) { + case 0: + case 1: + // No value needed. 
+ break; + case 2: // Conditional branch, return a bool + brVal = ConstantInt::get(RetTy, !SuccNum); + break; + default: + brVal = ConstantInt::get(RetTy, SuccNum); + break; + } + + ReturnInst::Create(Context, brVal, NewTarget); + } + + for (BasicBlock *Block : Blocks) { + Instruction *TI = Block->getTerminator(); + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + if (Blocks.count(TI->getSuccessor(i))) + continue; + BasicBlock *OldTarget = TI->getSuccessor(i); + // add a new basic block which returns the appropriate value + BasicBlock *NewTarget = ExitBlockMap[OldTarget]; + assert(NewTarget && "Unknown target block!"); + + if (KeepOldBlocks) { + VMap[OldTarget] = NewTarget; + } else { + // rewrite the original branch instruction with this new target + TI->setSuccessor(i, NewTarget); + } + } + } + + // Update values references to point to the new function. + if (KeepOldBlocks) { + for (BasicBlock *Pred : predecessors(header)) { + if (VMap.count(Pred)) + continue; + VMap[Pred] = newFuncRoot; + } + + for (Instruction *II : AdditionalRemap) + RemapInstruction(II, VMap, RF_NoModuleLevelChanges); + + // Loop over all of the instructions in the new function, fixing up operand + // references as we go. This uses VMap to do all the hard work. + for (BasicBlock *Block : Blocks) { + WeakTrackingVH NewBlock = VMap.lookup(Block); + if (!NewBlock) + continue; + + // Loop over all instructions, fixing each one as we find it... + for (Instruction &II : cast(*NewBlock)) + RemapInstruction(&II, VMap, RF_NoModuleLevelChanges); + } + } else { + // Loop over all of the PHI nodes in the header and exit blocks, and change + // any references to the old incoming edge to be the new incoming edge. 
+ for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { + PHINode *PN = cast(I); + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (!Blocks.count(PN->getIncomingBlock(i))) + PN->setIncomingBlock(i, newFuncRoot); + } + } + + BasicBlock *NewHeader = + KeepOldBlocks ? cast(VMap.lookup(header)) : header; + assert(NewHeader && "Header must have been cloned/moved to newFunction"); + + // Connect newFunction entry block to new header. + BranchInst *BranchI = BranchInst::Create(NewHeader, newFuncRoot); + applyFirstDebugLoc(oldFunction, Blocks.getArrayRef(), BranchI); + + // Store the arguments right after the definition of output value. + // This should be proceeded after creating exit stubs to be ensure that invoke + // result restore will be placed in the outlined function. + ScalarAI = newFunction->arg_begin(); + unsigned AggIdx = 0; + for (Value *Input : inputs) { + if (StructValues.contains(Input)) + ++AggIdx; + else + ++ScalarAI; + } + + for (Value *Output : outputs) { + if (KeepOldBlocks) + Output = VMap.lookup(Output); + + // Find proper insertion point. + // In case Output is an invoke, we insert the store at the beginning in the + // 'normal destination' BB. Otherwise we insert the store right after + // Output. + BasicBlock::iterator InsertPt; + if (auto *InvokeI = dyn_cast(Output)) + InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); + else if (auto *Phi = dyn_cast(Output)) + InsertPt = Phi->getParent()->getFirstInsertionPt(); + else if (auto *OutI = dyn_cast(Output)) + InsertPt = std::next(OutI->getIterator()); + else { + // Globals don't need to be updated, just advance to the next argument. 
+ if (StructValues.contains(Output)) + ++AggIdx; + else + ++ScalarAI; + continue; + } + + assert((InsertPt->getFunction() == newFunction || + Blocks.count(InsertPt->getParent())) && + "InsertPt should be in new function"); + + if (StructValues.contains(Output)) { + assert(AggArg && "Number of aggregate output arguments should match " + "the number of defined values"); + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, AggArg, Idx, "gep_" + Output->getName(), InsertPt); + new StoreInst(Output, GEP, InsertPt); + ++AggIdx; + } else { + assert(ScalarAI != newFunction->arg_end() && + "Number of scalar output arguments should match " + "the number of defined values"); + new StoreInst(Output, &*ScalarAI, InsertPt); + ++ScalarAI; + } + } + + if (SwitchCases.empty()) { + // Mark the new function `noreturn` if applicable. Terminators which resume + // exception propagation are treated as returning instructions. This is to + // avoid inserting traps after calls to outlined functions which unwind. 
+ if (none_of(Blocks, [](const BasicBlock *BB) { + const Instruction *Term = BB->getTerminator(); + return isa(Term) || isa(Term); + })) + newFunction->setDoesNotReturn(); + } +} + +CallInst *CodeExtractor::emitReplacerCall( + const ValueSet &inputs, const ValueSet &outputs, + const ValueSet &StructValues, Function *newFunction, + StructType *StructArgTy, Function *oldFunction, BasicBlock *ReplIP, + BlockFrequency EntryFreq, ArrayRef LifetimesStart, + std::vector &Reloads) { + LLVMContext &Context = oldFunction->getContext(); + Module *M = oldFunction->getParent(); + const DataLayout &DL = M->getDataLayout(); + + // This takes place of the original loop + BasicBlock *codeReplacer = + BasicBlock::Create(Context, "codeRepl", oldFunction, ReplIP); + codeReplacer->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; + BasicBlock *AllocaBlock = + AllocationBlock ? AllocationBlock : &oldFunction->getEntryBlock(); + AllocaBlock->IsNewDbgInfoFormat = oldFunction->IsNewDbgInfoFormat; // Update the entry count of the function. 
- if (BFI) { - auto Count = BFI->getProfileCountFromFreq(EntryFreq); - if (Count) - newFunction->setEntryCount( - ProfileCount(*Count, Function::PCT_Real)); // FIXME + if (BFI) BFI->setBlockFreq(codeReplacer, EntryFreq); + + std::vector params; + + // Add inputs as params, or to be filled into the struct + for (Value *input : inputs) { + if (StructValues.contains(input)) + continue; + + params.push_back(input); + } + + // Create allocas for the outputs + std::vector ReloadOutputs; + for (Value *output : outputs) { + if (StructValues.contains(output)) + continue; + + AllocaInst *alloca = new AllocaInst( + output->getType(), DL.getAllocaAddrSpace(), nullptr, + output->getName() + ".loc", AllocaBlock->getFirstInsertionPt()); + params.push_back(alloca); + ReloadOutputs.push_back(alloca); + } + + AllocaInst *Struct = nullptr; + if (!StructValues.empty()) { + Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr, + "structArg", AllocaBlock->getFirstInsertionPt()); + if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) { + auto *StructSpaceCast = new AddrSpaceCastInst( + Struct, PointerType ::get(Context, 0), "structArg.ascast"); + StructSpaceCast->insertAfter(Struct); + params.push_back(StructSpaceCast); + } else { + params.push_back(Struct); + } + + unsigned AggIdx = 0; + for (Value *input : inputs) { + if (!StructValues.contains(input)) + continue; + + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, Struct, Idx, "gep_" + input->getName()); + GEP->insertInto(codeReplacer, codeReplacer->end()); + new StoreInst(input, GEP, codeReplacer); + + ++AggIdx; + } } - CallInst *TheCall = - emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); + // Emit the call to the function + CallInst *call = CallInst::Create(newFunction, params, + SwitchCases.size() > 1 ? 
"targetBlock" : "", + codeReplacer); - moveCodeToFunction(newFunction); + // Set swifterror parameter attributes. + unsigned ParamIdx = 0; + unsigned AggIdx = 0; + for (auto input : inputs) { + if (StructValues.contains(input)) { + ++AggIdx; + } else { + if (input->isSwiftError()) + call->addParamAttr(ParamIdx, Attribute::SwiftError); + ++ParamIdx; + } + } + + // Add debug location to the new call, if the original function has debug + // info. In that case, the terminator of the entry block of the extracted + // function contains the first debug location of the extracted function, + // set in extractCodeRegion. + if (codeReplacer->getParent()->getSubprogram()) { + if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc()) + call->setDebugLoc(DL); + } + + // Reload the outputs passed in by reference, use the struct if output is in + // the aggregate or reload from the scalar argument. + for (unsigned i = 0, e = outputs.size(), scalarIdx = 0; i != e; ++i) { + Value *Output = nullptr; + if (StructValues.contains(outputs[i])) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), AggIdx); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); + GEP->insertInto(codeReplacer, codeReplacer->end()); + Output = GEP; + ++AggIdx; + } else { + Output = ReloadOutputs[scalarIdx]; + ++scalarIdx; + } + LoadInst *load = + new LoadInst(outputs[i]->getType(), Output, + outputs[i]->getName() + ".reload", codeReplacer); + Reloads.push_back(load); + } + + // Now we can emit a switch statement using the call as a value. 
+ SwitchInst *TheSwitch = + SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), + codeReplacer, 0, codeReplacer); + for (auto P : enumerate(SwitchCases)) { + BasicBlock *OldTarget = P.value(); + size_t SuccNum = P.index(); + + TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), SuccNum), + OldTarget); + } + + // Now that we've done the deed, simplify the switch instruction. + Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); + switch (SwitchCases.size()) { + case 0: + // There are no successors (the block containing the switch itself), which + // means that previously this was the last part of the function, and hence + // this should be rewritten as a `ret` or `unreachable`. + if (newFunction->doesNotReturn()) { + // If fn is no return, end with an unreachable terminator. + (void)new UnreachableInst(Context, TheSwitch->getIterator()); + } else if (OldFnRetTy->isVoidTy()) { + // We have no return value. + ReturnInst::Create(Context, nullptr, + TheSwitch->getIterator()); // Return void + } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { + // return what we have + ReturnInst::Create(Context, TheSwitch->getCondition(), + TheSwitch->getIterator()); + } else { + // Otherwise we must have code extracted an unwind or something, just + // return whatever we want. + ReturnInst::Create(Context, Constant::getNullValue(OldFnRetTy), + TheSwitch->getIterator()); + } + + TheSwitch->eraseFromParent(); + break; + case 1: + // Only a single destination, change the switch into an unconditional + // branch. + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getIterator()); + TheSwitch->eraseFromParent(); + break; + case 2: + // Only two destinations, convert to a condition branch. 
+ // Remark: This also swaps the target branches: + // 0 -> false -> getSuccessor(2); 1 -> true -> getSuccessor(1) + BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), + call, TheSwitch->getIterator()); + TheSwitch->eraseFromParent(); + break; + default: + // Otherwise, make the default destination of the switch instruction be one + // of the other successors. + TheSwitch->setCondition(call); + TheSwitch->setDefaultDest(TheSwitch->getSuccessor(SwitchCases.size())); + // Remove redundant case + TheSwitch->removeCase( + SwitchInst::CaseIt(TheSwitch, SwitchCases.size() - 1)); + break; + } + + // Insert lifetime markers around the reloads of any output values. The + // allocas output values are stored in are only in-use in the codeRepl block. + insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call); // Replicate the effects of any lifetime start/end markers which referenced // input objects in the extraction region by placing markers around the call. - insertLifetimeMarkersSurroundingCall( - oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall); + insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), LifetimesStart, + {}, call); - // Propagate personality info to the new function if there is one. - if (oldFunction->hasPersonalityFn()) - newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); + return call; +} - // Update the branch weights for the exit block. - if (BFI && NumExitBlocks > 1) - calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); +void CodeExtractor::insertReplacerCall( + Function *oldFunction, BasicBlock *header, BasicBlock *codeReplacer, + const ValueSet &outputs, ArrayRef Reloads, + const DenseMap &ExitWeights) { - // Loop over all of the PHI nodes in the header and exit blocks, and change - // any references to the old incoming edge to be the new incoming edge. 
- for (BasicBlock::iterator I = header->begin(); isa(I); ++I) { - PHINode *PN = cast(I); - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (!Blocks.count(PN->getIncomingBlock(i))) - PN->setIncomingBlock(i, newFuncRoot); - } + // Rewrite branches to basic blocks outside of the loop to new dummy blocks + // within the new function. This must be done before we lose track of which + // blocks were originally in the code region. + std::vector Users(header->user_begin(), header->user_end()); + for (auto &U : Users) + // The BasicBlock which contains the branch is not in the region + // modify the branch target to a new block + if (Instruction *I = dyn_cast(U)) + if (I->isTerminator() && I->getFunction() == oldFunction && + !Blocks.count(I->getParent())) + I->replaceUsesOfWith(header, codeReplacer); + + if (KeepOldBlocks) { + // Change references to output values after the call to use either the value + // written by the extracted function or the original value if we skipped the + // call. Use SSAUpdater to propagate the new PHI since the CFG has changed. + + SSAUpdater SSA; + for (auto P : enumerate(outputs)) { + size_t OutIdx = P.index(); + Instruction *OldVal = cast(P.value()); + Value *NewVal = Reloads[OutIdx]; + + SSA.Initialize(OldVal->getType(), + (OldVal->getName() + ".merge_with_extracted").str()); + SSA.AddAvailableValue(codeReplacer, NewVal); + + // Could help SSAUpdater by determining in advance which output values are + // available in which exit blocks (from DT). 
+ SSA.AddAvailableValue(OldVal->getParent(), OldVal); + + for (Use &U : make_early_inc_range(OldVal->uses())) { + auto *User = dyn_cast(U.getUser()); + if (!User) + continue; + BasicBlock *EffectiveUser = User->getParent(); + if (auto *PHI = dyn_cast(User)) + EffectiveUser = PHI->getIncomingBlock(U); - for (BasicBlock *ExitBB : ExitBlocks) - for (PHINode &PN : ExitBB->phis()) { - Value *IncomingCodeReplacerVal = nullptr; - for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { - // Ignore incoming values from outside of the extracted region. - if (!Blocks.count(PN.getIncomingBlock(i))) + if (EffectiveUser == codeReplacer || Blocks.count(EffectiveUser)) continue; - // Ensure that there is only one incoming value from codeReplacer. - if (!IncomingCodeReplacerVal) { - PN.setIncomingBlock(i, codeReplacer); - IncomingCodeReplacerVal = PN.getIncomingValue(i); - } else - assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) && - "PHI has two incompatbile incoming values from codeRepl"); + SSA.RewriteUseAfterInsertions(U); } } + } else { + // When moving the code region it is sufficient to replace all uses to the + // extracted function values. Since the original definition's block + // dominated its use, it will also be dominated by codeReplacer's switch + // which joined multiple exit blocks. + + for (BasicBlock *ExitBB : SwitchCases) + for (PHINode &PN : ExitBB->phis()) { + Value *IncomingCodeReplacerVal = nullptr; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + // Ignore incoming values from outside of the extracted region. + if (!Blocks.count(PN.getIncomingBlock(i))) + continue; - fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall); + // Ensure that there is only one incoming value from codeReplacer. 
+ if (!IncomingCodeReplacerVal) { + PN.setIncomingBlock(i, codeReplacer); + IncomingCodeReplacerVal = PN.getIncomingValue(i); + } else + assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) && + "PHI has two incompatbile incoming values from codeRepl"); + } + } - LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) { - newFunction->dump(); - report_fatal_error("verification of newFunction failed!"); - }); - LLVM_DEBUG(if (verifyFunction(*oldFunction)) - report_fatal_error("verification of oldFunction failed!")); - LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC)) - report_fatal_error("Stale Asumption cache for old Function!")); - return newFunction; + for (unsigned i = 0, e = outputs.size(); i != e; ++i) { + Value *load = Reloads[i]; + std::vector Users(outputs[i]->user_begin(), + outputs[i]->user_end()); + for (User *U : Users) { + Instruction *inst = cast(U); + if (inst->getParent()->getParent() == oldFunction) + inst->replaceUsesOfWith(outputs[i], load); + } + } + } + + // Update the branch weights for the exit block. 
+ if (BFI && SwitchCases.size() > 1) + calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); } bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc, diff --git a/llvm/test/tools/llvm-extract/extract-block-cleanup.ll b/llvm/test/tools/llvm-extract/extract-block-cleanup.ll new file mode 100644 index 0000000000000..8b44645c4149a --- /dev/null +++ b/llvm/test/tools/llvm-extract/extract-block-cleanup.ll @@ -0,0 +1,116 @@ +; RUN: llvm-extract -S -bb "foo:region_start;extractonly;cleanup;fallback;region_end" --replace-with-call %s | FileCheck %s + + +; CHECK-LABEL: define void @foo(ptr %arg, i1 %c) { +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %c, label %codeRepl, label %outsideonly +; CHECK-EMPTY: +; CHECK-NEXT: outsideonly: +; CHECK-NEXT: store i32 0, ptr %arg, align 4 +; CHECK-NEXT: br label %cleanup +; CHECK-EMPTY: +; CHECK-NEXT: codeRepl: +; CHECK-NEXT: %targetBlock = call i1 @foo.region_start(ptr %arg) +; CHECK-NEXT: br i1 %targetBlock, label %cleanup.return_crit_edge, label %region_end.split +; CHECK-EMPTY: +; CHECK-NEXT: region_start: +; CHECK-NEXT: br label %extractonly +; CHECK-EMPTY: +; CHECK-NEXT: extractonly: +; CHECK-NEXT: store i32 1, ptr %arg, align 4 +; CHECK-NEXT: br label %cleanup +; CHECK-EMPTY: +; CHECK-NEXT: cleanup: +; CHECK-NEXT: %dest = phi i8 [ 0, %outsideonly ], [ 1, %extractonly ] +; CHECK-NEXT: switch i8 %dest, label %fallback [ +; CHECK-NEXT: i8 0, label %cleanup.return_crit_edge +; CHECK-NEXT: i8 1, label %region_end +; CHECK-NEXT: ] +; CHECK-EMPTY: +; CHECK-NEXT: cleanup.return_crit_edge: +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: fallback: +; CHECK-NEXT: unreachable +; CHECK-EMPTY: +; CHECK-NEXT: region_end: +; CHECK-NEXT: br label %region_end.split +; CHECK-EMPTY: +; CHECK-NEXT: region_end.split: +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: outsidecont: +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: return: +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; 
CHECK-LABEL: define internal i1 @foo.region_start(ptr %arg) { +; CHECK-NEXT: newFuncRoot: +; CHECK-NEXT: br label %region_start +; CHECK-EMPTY: +; CHECK-NEXT: region_start: +; CHECK-NEXT: br label %extractonly +; CHECK-EMPTY: +; CHECK-NEXT: extractonly: +; CHECK-NEXT: store i32 1, ptr %arg, align 4 +; CHECK-NEXT: br label %cleanup +; CHECK-EMPTY: +; CHECK-NEXT: cleanup: +; CHECK-NEXT: %dest = phi i8 [ 1, %extractonly ] +; CHECK-NEXT: switch i8 %dest, label %fallback [ +; CHECK-NEXT: i8 0, label %cleanup.return_crit_edge.exitStub +; CHECK-NEXT: i8 1, label %region_end +; CHECK-NEXT: ] +; CHECK-EMPTY: +; CHECK-NEXT: fallback: +; CHECK-NEXT: unreachable +; CHECK-EMPTY: +; CHECK-NEXT: region_end: +; CHECK-NEXT: br label %region_end.split.exitStub +; CHECK-EMPTY: +; CHECK-NEXT: cleanup.return_crit_edge.exitStub: +; CHECK-NEXT: ret i1 true +; CHECK-EMPTY: +; CHECK-NEXT: region_end.split.exitStub: +; CHECK-NEXT: ret i1 false +; CHECK-NEXT: } + + + +define void @foo(ptr %arg, i1 %c) { +entry: + br i1 %c, label %region_start, label %outsideonly + +outsideonly: + store i32 0, ptr %arg, align 4 + br label %cleanup + +region_start: + br label %extractonly + +extractonly: + store i32 1, ptr %arg, align 4 + br label %cleanup + +cleanup: + %dest = phi i8 [0, %outsideonly], [1, %extractonly] + switch i8 %dest, label %fallback [ + i8 0, label %return + i8 1, label %region_end + ] + +fallback: + unreachable + +region_end: + br label %return + +outsidecont: + br label %return + +return: + ret void +} diff --git a/llvm/test/tools/llvm-extract/extract-block-multiple-exits.ll b/llvm/test/tools/llvm-extract/extract-block-multiple-exits.ll new file mode 100644 index 0000000000000..b7475aebd7770 --- /dev/null +++ b/llvm/test/tools/llvm-extract/extract-block-multiple-exits.ll @@ -0,0 +1,196 @@ +; RUN: llvm-extract -S -bb "func:region_start;exiting0;exiting1" --replace-with-call %s | FileCheck %s + + +; CHECK-LABEL: define void @func(ptr %arg, i1 %c0, i1 %c1, i1 %c2, i8 %dest) { +; 
CHECK-NEXT: entry: +; CHECK-NEXT: %B.ce.loc = alloca i32, align 4 +; CHECK-NEXT: %c.loc = alloca i32, align 4 +; CHECK-NEXT: %b.loc = alloca i32, align 4 +; CHECK-NEXT: %a.loc = alloca i32, align 4 +; CHECK-NEXT: br i1 %c0, label %codeRepl, label %exit +; CHECK-EMPTY: +; CHECK-NEXT: codeRepl: +; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr %a.loc) +; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr %b.loc) +; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr %c.loc) +; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr %B.ce.loc) +; CHECK-NEXT: %targetBlock = call i16 @func.region_start(i1 %c1, i1 %c2, i8 %dest, ptr %a.loc, ptr %b.loc, ptr %c.loc, ptr %B.ce.loc) +; CHECK-NEXT: %a.reload = load i32, ptr %a.loc, align 4 +; CHECK-NEXT: %b.reload = load i32, ptr %b.loc, align 4 +; CHECK-NEXT: %c.reload = load i32, ptr %c.loc, align 4 +; CHECK-NEXT: %B.ce.reload = load i32, ptr %B.ce.loc, align 4 +; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr %a.loc) +; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr %b.loc) +; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr %c.loc) +; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr %B.ce.loc) +; CHECK-NEXT: switch i16 %targetBlock, label %exit0 [ +; CHECK-NEXT: i16 0, label %exiting0.exit_crit_edge +; CHECK-NEXT: i16 1, label %fallback +; CHECK-NEXT: i16 2, label %exit1 +; CHECK-NEXT: i16 3, label %exit2 +; CHECK-NEXT: ] +; CHECK-EMPTY: +; CHECK-NEXT: region_start: +; CHECK-NEXT: %a = add i32 42, 1 +; CHECK-NEXT: br i1 %c1, label %exiting0, label %exiting1 +; CHECK-EMPTY: +; CHECK-NEXT: exiting0: +; CHECK-NEXT: %b = add i32 42, 2 +; CHECK-NEXT: br i1 %c2, label %exiting0.exit_crit_edge, label %exit0.split +; CHECK-EMPTY: +; CHECK-NEXT: exiting0.exit_crit_edge: +; CHECK-NEXT: %b.merge_with_extracted4 = phi i32 [ %b.reload, %codeRepl ], [ %b, %exiting0 ] +; CHECK-NEXT: br label %exit +; CHECK-EMPTY: +; CHECK-NEXT: exiting1: +; CHECK-NEXT: %c = add i32 42, 3 +; 
CHECK-NEXT: switch i8 %dest, label %fallback [ +; CHECK-NEXT: i8 0, label %exit0.split +; CHECK-NEXT: i8 1, label %exit1 +; CHECK-NEXT: i8 2, label %exit2 +; CHECK-NEXT: i8 3, label %exit0.split +; CHECK-NEXT: ] +; CHECK-EMPTY: +; CHECK-NEXT: fallback: +; CHECK-NEXT: unreachable +; CHECK-EMPTY: +; CHECK-NEXT: exit: +; CHECK-NEXT: %A = phi i32 [ 42, %entry ], [ %b.merge_with_extracted4, %exiting0.exit_crit_edge ] +; CHECK-NEXT: store i32 %A, ptr %arg, align 4 +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: exit0.split: +; CHECK-NEXT: %b.merge_with_extracted3 = phi i32 [ %b, %exiting0 ], [ poison, %exiting1 ], [ poison, %exiting1 ] +; CHECK-NEXT: %B.ce = phi i32 [ %b, %exiting0 ], [ %a, %exiting1 ], [ %a, %exiting1 ] +; CHECK-NEXT: br label %exit0 +; CHECK-EMPTY: +; CHECK-NEXT: exit0: +; CHECK-NEXT: %B.ce.merge_with_extracted = phi i32 [ %B.ce.reload, %codeRepl ], [ %B.ce, %exit0.split ] +; CHECK-NEXT: %b.merge_with_extracted = phi i32 [ %b.reload, %codeRepl ], [ %b.merge_with_extracted3, %exit0.split ] +; CHECK-NEXT: %a.merge_with_extracted2 = phi i32 [ %a.reload, %codeRepl ], [ %a, %exit0.split ] +; CHECK-NEXT: store i32 %a.merge_with_extracted2, ptr %arg, align 4 +; CHECK-NEXT: store i32 %B.ce.merge_with_extracted, ptr %arg, align 4 +; CHECK-NEXT: br label %after +; CHECK-EMPTY: +; CHECK-NEXT: exit1: +; CHECK-NEXT: %c.merge_with_extracted5 = phi i32 [ %c.reload, %codeRepl ], [ %c, %exiting1 ] +; CHECK-NEXT: %a.merge_with_extracted1 = phi i32 [ %a.reload, %codeRepl ], [ %a, %exiting1 ] +; CHECK-NEXT: br label %after +; CHECK-EMPTY: +; CHECK-NEXT: exit2: +; CHECK-NEXT: %c.merge_with_extracted = phi i32 [ %c.reload, %codeRepl ], [ %c, %exiting1 ] +; CHECK-NEXT: store i32 %c.merge_with_extracted, ptr %arg, align 4 +; CHECK-NEXT: store i32 %c.merge_with_extracted, ptr %arg, align 4 +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: after: +; CHECK-NEXT: %a.merge_with_extracted = phi i32 [ %a.merge_with_extracted2, %exit0 ], [ 
%a.merge_with_extracted1, %exit1 ] +; CHECK-NEXT: %D = phi i32 [ %b.merge_with_extracted, %exit0 ], [ %c.merge_with_extracted5, %exit1 ] +; CHECK-NEXT: store i32 %a.merge_with_extracted, ptr %arg, align 4 +; CHECK-NEXT: store i32 %D, ptr %arg, align 4 +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: return: +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +; CHECK-LABEL: define internal i16 @func.region_start(i1 %c1, i1 %c2, i8 %dest, ptr %a.out, ptr %b.out, ptr %c.out, ptr %B.ce.out) { +; CHECK-NEXT: newFuncRoot: +; CHECK-NEXT: br label %region_start +; CHECK-EMPTY: +; CHECK-NEXT: region_start: +; CHECK-NEXT: %a = add i32 42, 1 +; CHECK-NEXT: store i32 %a, ptr %a.out, align 4 +; CHECK-NEXT: br i1 %c1, label %exiting0, label %exiting1 +; CHECK-EMPTY: +; CHECK-NEXT: exiting0: +; CHECK-NEXT: %b = add i32 42, 2 +; CHECK-NEXT: store i32 %b, ptr %b.out, align 4 +; CHECK-NEXT: br i1 %c2, label %exiting0.exit_crit_edge.exitStub, label %exit0.split +; CHECK-EMPTY: +; CHECK-NEXT: exiting1: +; CHECK-NEXT: %c = add i32 42, 3 +; CHECK-NEXT: store i32 %c, ptr %c.out, align 4 +; CHECK-NEXT: switch i8 %dest, label %fallback.exitStub [ +; CHECK-NEXT: i8 0, label %exit0.split +; CHECK-NEXT: i8 1, label %exit1.exitStub +; CHECK-NEXT: i8 2, label %exit2.exitStub +; CHECK-NEXT: i8 3, label %exit0.split +; CHECK-NEXT: ] +; CHECK-EMPTY: +; CHECK-NEXT: exit0.split: +; CHECK-NEXT: %B.ce = phi i32 [ %b, %exiting0 ], [ %a, %exiting1 ], [ %a, %exiting1 ] +; CHECK-NEXT: store i32 %B.ce, ptr %B.ce.out, align 4 +; CHECK-NEXT: br label %exit0.exitStub +; CHECK-EMPTY: +; CHECK-NEXT: exiting0.exit_crit_edge.exitStub: +; CHECK-NEXT: ret i16 0 +; CHECK-EMPTY: +; CHECK-NEXT: fallback.exitStub: +; CHECK-NEXT: ret i16 1 +; CHECK-EMPTY: +; CHECK-NEXT: exit1.exitStub: +; CHECK-NEXT: ret i16 2 +; CHECK-EMPTY: +; CHECK-NEXT: exit2.exitStub: +; CHECK-NEXT: ret i16 3 +; CHECK-EMPTY: +; CHECK-NEXT: exit0.exitStub: +; CHECK-NEXT: ret i16 4 +; CHECK-NEXT: } + + +define void @func(ptr %arg, i1 %c0, i1 
%c1, i1 %c2, i8 %dest) { +entry: + br i1 %c0, label %region_start, label %exit + +region_start: + %a = add i32 42, 1 + br i1 %c1, label %exiting0, label %exiting1 + +exiting0: + %b = add i32 42, 2 + br i1 %c2, label %exit, label %exit0 + +exiting1: + %c = add i32 42, 3 + switch i8 %dest, label %fallback [ + i8 0, label %exit0 + i8 1, label %exit1 + i8 2, label %exit2 + i8 3, label %exit0 + ] + +fallback: + unreachable + +exit: + %A = phi i32 [ 42, %entry ], [ %b, %exiting0 ] + store i32 %A, ptr %arg + br label %return + +exit0: + %B = phi i32 [ %b, %exiting0 ], [ %a, %exiting1 ] , [ %a, %exiting1 ] + store i32 %a, ptr %arg + store i32 %B, ptr %arg + br label %after + +exit1: + br label %after + +exit2: + %C = phi i32 [ %c, %exiting1 ] + store i32 %c, ptr %arg + store i32 %C, ptr %arg + br label %return + +after: + %D = phi i32 [ %b, %exit0 ], [ %c, %exit1 ] + store i32 %a, ptr %arg + store i32 %D, ptr %arg + br label %return + +return: + ret void +} diff --git a/llvm/test/tools/llvm-extract/extract-block-sink.ll b/llvm/test/tools/llvm-extract/extract-block-sink.ll new file mode 100644 index 0000000000000..66e0b4a7799f4 --- /dev/null +++ b/llvm/test/tools/llvm-extract/extract-block-sink.ll @@ -0,0 +1,60 @@ +; RUN: llvm-extract -S -bb "foo:region_start" --replace-with-call %s | FileCheck %s + +; CHECK-LABEL: define void @foo() { +; CHECK-NEXT: entry: +; CHECK-NEXT: %a = alloca i32, align 4 +; CHECK-NEXT: %b = alloca i32, align 4 +; CHECK-NEXT: br label %codeRepl +; CHECK-EMPTY: +; CHECK-NEXT: codeRepl: +; CHECK-NEXT: call void @foo.region_start(ptr %b) +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: region_start: +; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %a) +; CHECK-NEXT: store i32 43, ptr %a, align 4 +; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %a) +; CHECK-NEXT: store i32 44, ptr %b, align 4 +; CHECK-NEXT: br label %return +; CHECK-EMPTY: +; CHECK-NEXT: return: +; CHECK-NEXT: ret void +; CHECK-NEXT: } + 
+ +; CHECK-LABEL: define internal void @foo.region_start(ptr %b) { +; CHECK-NEXT: newFuncRoot: +; CHECK-NEXT: %a = alloca i32, align 4 +; CHECK-NEXT: br label %region_start +; CHECK-EMPTY: +; CHECK-NEXT: region_start: +; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %a) +; CHECK-NEXT: store i32 43, ptr %a, align 4 +; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %a) +; CHECK-NEXT: store i32 44, ptr %b, align 4 +; CHECK-NEXT: br label %return.exitStub +; CHECK-EMPTY: +; CHECK-NEXT: return.exitStub: +; CHECK-NEXT: ret void +; CHECK-NEXT: } + + +declare void @llvm.lifetime.start.p0i32(i64, ptr nocapture) +declare void @llvm.lifetime.end.p0i32(i64, ptr nocapture) + +define void @foo() { +entry: + %a = alloca i32, align 4 + %b = alloca i32, align 4 + br label %region_start + +region_start: + call void @llvm.lifetime.start.p0i32(i64 4, ptr nonnull %a) + store i32 43, ptr %a + call void @llvm.lifetime.end.p0i32(i64 4, ptr nonnull %a) + store i32 44, ptr %b + br label %return + +return: + ret void +} diff --git a/llvm/test/tools/llvm-extract/extract-block.ll b/llvm/test/tools/llvm-extract/extract-block.ll index 03caf138aa36e..03093eb6bbdc0 100644 --- a/llvm/test/tools/llvm-extract/extract-block.ll +++ b/llvm/test/tools/llvm-extract/extract-block.ll @@ -1,4 +1,6 @@ -; RUN: llvm-extract -S -bb foo:bb4 %s | FileCheck %s +; RUN: llvm-extract -S -bb foo:bb4 %s | FileCheck %s --check-prefixes=CHECK,KILL +; RUN: llvm-extract -S -bb foo:bb4 --replace-with-call %s | FileCheck %s --check-prefixes=CHECK,KEEP + ; CHECK: declare void @bar() define void @bar() { @@ -12,7 +14,11 @@ bb: ret void } -; CHECK: @foo.bb4 +; KEEP-LABEL: define i32 @foo(i32 %arg) { +; KEEP: call void @foo.bb4 + +; KILL-LABEL: define dso_local void @foo.bb4( +; KEEP-LABEL: define internal void @foo.bb4( ; CHECK: call void @bar() ; CHECK: %tmp5 define i32 @foo(i32 %arg) { diff --git a/llvm/test/tools/llvm-extract/extract-blocks-with-groups.ll 
b/llvm/test/tools/llvm-extract/extract-blocks-with-groups.ll index 057e70008ff96..0c2f989123549 100644 --- a/llvm/test/tools/llvm-extract/extract-blocks-with-groups.ll +++ b/llvm/test/tools/llvm-extract/extract-blocks-with-groups.ll @@ -1,10 +1,19 @@ -; RUN: llvm-extract -bb 'foo:if;then;else' -bb 'bar:bb14;bb20' -S %s | FileCheck %s +; RUN: llvm-extract -bb 'foo:if;then;else' -bb 'bar:bb14;bb20' -S %s | FileCheck %s --check-prefixes=CHECK,KILL +; RUN: llvm-extract -bb 'foo:if;then;else' -bb 'bar:bb14;bb20' --replace-with-call -S %s | FileCheck %s --check-prefixes=CHECK,KEEP ; Extract two groups of basic blocks in two different functions. +; KEEP-LABEL: define i32 @foo(i32 %arg, i32 %arg1) { +; KEEP: call void @foo.if.split( + +; KEEP-LABEL: define i32 @bar(i32 %arg, i32 %arg1) { +; KEEP: %targetBlock = call i1 @bar.bb14( + + ; The first extracted function is the region composed by the ; blocks if, then, and else from foo. -; CHECK: define dso_local void @foo.if.split(i32 %arg1, i32 %arg, ptr %tmp.0.ce.out) { +; KILL-LABEL: define dso_local void @foo.if.split(i32 %arg1, i32 %arg, ptr %tmp.0.ce.out) { +; KEEP-LABEL: define internal void @foo.if.split(i32 %arg1, i32 %arg, ptr %tmp.0.ce.out) { ; CHECK: newFuncRoot: ; CHECK: br label %if.split ; @@ -25,7 +34,7 @@ ; CHECK: %or.cond = and i1 %tmp5, %tmp8 ; CHECK: br i1 %or.cond, label %then, label %else ; -; CHECK: end.split: ; preds = %then, %else +; CHECK: end.split: ; CHECK: %tmp.0.ce = phi i32 [ %tmp13, %then ], [ %tmp25, %else ] ; CHECK: store i32 %tmp.0.ce, ptr %tmp.0.ce.out ; CHECK: br label %end.exitStub @@ -36,7 +45,8 @@ ; The second extracted function is the region composed by the blocks ; bb14 and bb20 from bar. 
-; CHECK: define dso_local i1 @bar.bb14(i32 %arg1, i32 %arg, ptr %tmp25.out) { +; KILL-LABEL: define dso_local i1 @bar.bb14(i32 %arg1, i32 %arg, ptr %tmp25.out) { +; KEEP-LABEL: define internal i1 @bar.bb14(i32 %arg1, i32 %arg, ptr %tmp25.out) { ; CHECK: newFuncRoot: ; CHECK: br label %bb14 ; @@ -50,12 +60,14 @@ ; CHECK: %tmp24 = sdiv i32 %arg1, 6 ; CHECK: %tmp25 = add nsw i32 %tmp24, %tmp22 ; CHECK: store i32 %tmp25, ptr %tmp25.out -; CHECK: br label %bb30.exitStub +; KILL: br label %bb30.exitStub +; KEEP: br label %bb20.split.exitStub ; ; CHECK: bb26.exitStub: ; preds = %bb14 ; CHECK: ret i1 true ; -; CHECK: bb30.exitStub: ; preds = %bb20 +; KILL: bb30.exitStub: ; preds = %bb20 +; KEEP: bb20.split.exitStub: ; CHECK: ret i1 false ; CHECK: } diff --git a/llvm/tools/llvm-extract/llvm-extract.cpp b/llvm/tools/llvm-extract/llvm-extract.cpp index 5fc9a31ab4ad7..dd49bcce39bfb 100644 --- a/llvm/tools/llvm-extract/llvm-extract.cpp +++ b/llvm/tools/llvm-extract/llvm-extract.cpp @@ -89,13 +89,21 @@ static cl::list ExtractBlocks( "Specify pairs to extract.\n" "Each pair will create a function.\n" "If multiple basic blocks are specified in one pair,\n" - "the first block in the sequence should dominate the rest.\n" + "the first block in the sequence should dominate the rest (Unless " + "using --bb-keep-blocks).\n" "eg:\n" " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " "with bb2."), cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat)); +static cl::opt ReplaceWithCall( + "replace-with-call", + cl::desc( + "When extracting blocks from functions, keep the original functions; " + "extracted code is replaced by function call to new function"), + cl::cat(ExtractCat)); + // ExtractAlias - The alias to extract from the module. 
static cl::list ExtractAliases("alias", cl::desc("Specify alias to extract"), @@ -383,7 +391,9 @@ int main(int argc, char **argv) { PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); ModulePassManager PM; - PM.addPass(BlockExtractorPass(std::move(GroupOfBBs), true)); + PM.addPass(BlockExtractorPass(std::move(GroupOfBBs), + /*EraseFunction=*/!ReplaceWithCall, + /*KeepOldBlocks=*/ReplaceWithCall)); PM.run(*M, MAM); } diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp index 80c2a23a95796..3f513dc70ae98 100644 --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -7,11 +7,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CodeExtractor.h" -#include "llvm/AsmParser/Parser.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/AsmParser/Parser.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" @@ -30,6 +31,13 @@ BasicBlock *getBlockByName(Function *F, StringRef name) { return nullptr; } +Instruction *getInstByName(Function *F, StringRef Name) { + for (Instruction &I : instructions(F)) + if (I.getName() == Name) + return &I; + return nullptr; +} + TEST(CodeExtractor, ExitStub) { LLVMContext Ctx; SMDiagnostic Err; @@ -513,19 +521,28 @@ TEST(CodeExtractor, PartialAggregateArgs) { target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" target triple = "x86_64-unknown-linux-gnu" - declare void @use(i32) + ; use different types such that an index mismatch will result in a type mismatch during verification. 
+ declare void @use16(i16) + declare void @use32(i32) + declare void @use64(i64) - define void @foo(i32 %a, i32 %b, i32 %c) { + define void @foo(i16 %a, i32 %b, i64 %c) { entry: br label %extract extract: - call void @use(i32 %a) - call void @use(i32 %b) - call void @use(i32 %c) + call void @use16(i16 %a) + call void @use32(i32 %b) + call void @use64(i64 %c) + %d = add i16 21, 21 + %e = add i32 21, 21 + %f = add i64 21, 21 br label %exit exit: + call void @use16(i16 %d) + call void @use32(i32 %e) + call void @use64(i64 %f) ret void } )ir", @@ -544,18 +561,70 @@ TEST(CodeExtractor, PartialAggregateArgs) { BasicBlock *CommonExit = nullptr; CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); CE.findInputsOutputs(Inputs, Outputs, SinkingCands); - // Exclude the first input from the argument aggregate. - CE.excludeArgFromAggregate(Inputs[0]); + // Exclude the middle input and output from the argument aggregate. + CE.excludeArgFromAggregate(Inputs[1]); + CE.excludeArgFromAggregate(Outputs[1]); Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); EXPECT_TRUE(Outlined); - // Expect 2 arguments in the outlined function: the excluded input and the - // struct aggregate for the remaining inputs. - EXPECT_EQ(Outlined->arg_size(), 2U); + // Expect 3 arguments in the outlined function: the excluded input, the + // excluded output, and the struct aggregate for the remaining inputs. 
+ EXPECT_EQ(Outlined->arg_size(), 3U); EXPECT_FALSE(verifyFunction(*Outlined)); EXPECT_FALSE(verifyFunction(*Func)); } +TEST(CodeExtractor, AllocaBlock) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr M(parseAssemblyString(R"invalid( + define i32 @foo(i32 %x, i32 %y, i32 %z) { + entry: + br label %allocas + + allocas: + br label %body + + body: + %w = add i32 %x, %y + br label %notExtracted + + notExtracted: + %r = add i32 %w, %x + ret i32 %r + } + )invalid", + Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector Candidates{getBlockByName(Func, "body")}; + + BasicBlock *AllocaBlock = getBlockByName(Func, "allocas"); + CodeExtractor CE(Candidates, nullptr, true, nullptr, nullptr, nullptr, false, + false, AllocaBlock); + CE.excludeArgFromAggregate(Func->getArg(0)); + CE.excludeArgFromAggregate(getInstByName(Func, "w")); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + SetVector Inputs, Outputs; + Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); + EXPECT_TRUE(Outlined); + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); + + // The only added allocas may be in the dedicated alloca block. There should + // be one alloca for the struct, and another one for the reload value. + int NumAllocas = 0; + for (Instruction &I : instructions(Func)) { + if (!isa(I)) + continue; + EXPECT_EQ(I.getParent(), AllocaBlock); + NumAllocas += 1; + } + EXPECT_EQ(NumAllocas, 2); +} + /// Regression test to ensure we don't crash trying to set the name of the ptr /// argument TEST(CodeExtractor, PartialAggregateArgs2) {