diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp index ed4ad15e5ab69..fa467cc72bd02 100644 --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -823,17 +823,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, std::vector<Type *> ParamTy; std::vector<Type *> AggParamTy; + std::vector<std::tuple<unsigned, Value *>> NumberedInputs; + std::vector<std::tuple<unsigned, Value *>> NumberedOutputs; ValueSet StructValues; const DataLayout &DL = M->getDataLayout(); // Add the types of the input values to the function's argument list + unsigned ArgNum = 0; for (Value *value : inputs) { LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n"); if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) { AggParamTy.push_back(value->getType()); StructValues.insert(value); - } else + } else { ParamTy.push_back(value->getType()); + NumberedInputs.emplace_back(ArgNum++, value); + } } // Add the types of the output values to the function's argument list. @@ -842,9 +847,11 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) { AggParamTy.push_back(output->getType()); StructValues.insert(output); - } else + } else { ParamTy.push_back( PointerType::get(output->getType(), DL.getAllocaAddrSpace())); + NumberedOutputs.emplace_back(ArgNum++, output); + } } assert( @@ -1053,15 +1060,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, } // Set names for input and output arguments. 
- if (NumScalarParams) { - ScalarAI = newFunction->arg_begin(); - for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI) - if (!StructValues.contains(inputs[i])) - ScalarAI->setName(inputs[i]->getName()); - for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI) - if (!StructValues.contains(outputs[i])) - ScalarAI->setName(outputs[i]->getName() + ".out"); - } + for (auto [i, argVal] : NumberedInputs) + newFunction->getArg(i)->setName(argVal->getName()); + for (auto [i, argVal] : NumberedOutputs) + newFunction->getArg(i)->setName(argVal->getName() + ".out"); // Rewrite branches to basic blocks outside of the loop to new dummy blocks // within the new function. This must be done before we lose track of which diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp index 046010716862f..80c2a23a95796 100644 --- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp +++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp @@ -556,6 +556,53 @@ TEST(CodeExtractor, PartialAggregateArgs) { EXPECT_FALSE(verifyFunction(*Func)); } +/// Regression test to ensure we don't crash trying to set the name of the ptr +/// argument +TEST(CodeExtractor, PartialAggregateArgs2) { + LLVMContext Ctx; + SMDiagnostic Err; + std::unique_ptr<Module> M(parseAssemblyString(R"ir( + declare void @usei(i32) + declare void @usep(ptr) + + define void @foo(i32 %a, i32 %b, ptr %p) { + entry: + br label %extract + + extract: + call void @usei(i32 %a) + call void @usei(i32 %b) + call void @usep(ptr %p) + br label %exit + + exit: + ret void + } + )ir", + Err, Ctx)); + + Function *Func = M->getFunction("foo"); + SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; + + // Create the CodeExtractor with arguments aggregation enabled. 
+ CodeExtractor CE(Blocks, /* DominatorTree */ nullptr, + /* AggregateArgs */ true); + EXPECT_TRUE(CE.isEligible()); + + CodeExtractorAnalysisCache CEAC(*Func); + SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; + CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + CE.findInputsOutputs(Inputs, Outputs, SinkingCands); + // Exclude the last input from the argument aggregate. + CE.excludeArgFromAggregate(Inputs[2]); + + Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); + EXPECT_TRUE(Outlined); + EXPECT_FALSE(verifyFunction(*Outlined)); + EXPECT_FALSE(verifyFunction(*Func)); +} + TEST(CodeExtractor, OpenMPAggregateArgs) { LLVMContext Ctx; SMDiagnostic Err;