diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h index 826347e79f719..98ac4f726e51a 100644 --- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h +++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h @@ -102,7 +102,6 @@ class CodeExtractorAnalysisCache { // Bits of intermediate state computed at various phases of extraction. SetVector Blocks; unsigned NumExitBlocks = std::numeric_limits::max(); - Type *RetTy; // Mapping from the original exit blocks, to the new blocks inside // the function. @@ -238,6 +237,10 @@ class CodeExtractorAnalysisCache { getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC, Instruction *Addr, BasicBlock *ExitBlock) const; + /// Return the type used for the return code of the extracted function to + /// indicate which exit block to jump to. + Type *getSwitchType(); + void severSplitPHINodesOfEntry(BasicBlock *&Header); void severSplitPHINodesOfExits(const SetVector &Exits); void splitReturnBlocks(); diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index fa467cc72bd02..bbe4eddd39bf5 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -813,14 +813,6 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); - // This function returns unsigned, outputs will go back by reference. - switch (NumExitBlocks) { - case 0: - case 1: RetTy = Type::getVoidTy(header->getContext()); break; - case 2: RetTy = Type::getInt1Ty(header->getContext()); break; - default: RetTy = Type::getInt16Ty(header->getContext()); break; - } - std::vector ParamTy; std::vector AggParamTy; std::vector> NumberedInputs; @@ -870,6 +862,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace())); } + Type *RetTy = getSwitchType(); LLVM_DEBUG({ dbgs() << "Function type: " << *RetTy << " f("; for (Type *i : ParamTy) @@ -1080,6 +1073,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, return newFunction; } +Type *CodeExtractor::getSwitchType() { + LLVMContext &Context = Blocks.front()->getContext(); + + assert(NumExitBlocks < 0xffff && "too many exit blocks for switch"); + switch (NumExitBlocks) { + case 0: + case 1: + return Type::getVoidTy(Context); + case 2: + // Conditional branch, return a bool + return Type::getInt1Ty(Context); + default: + return Type::getInt16Ty(Context); + } +} + /// Erase lifetime.start markers which reference inputs to the extraction /// region, and insert the referenced memory into \p LifetimesStart. ///