Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions llvm/lib/Transforms/Utils/CodeExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,17 +823,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,

std::vector<Type *> ParamTy;
std::vector<Type *> AggParamTy;
std::vector<std::tuple<unsigned, Value *>> NumberedInputs;
std::vector<std::tuple<unsigned, Value *>> NumberedOutputs;
ValueSet StructValues;
const DataLayout &DL = M->getDataLayout();

// Add the types of the input values to the function's argument list
unsigned ArgNum = 0;
for (Value *value : inputs) {
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
AggParamTy.push_back(value->getType());
StructValues.insert(value);
} else
} else {
ParamTy.push_back(value->getType());
NumberedInputs.emplace_back(ArgNum++, value);
}
}

// Add the types of the output values to the function's argument list.
Expand All @@ -842,9 +847,11 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
AggParamTy.push_back(output->getType());
StructValues.insert(output);
} else
} else {
ParamTy.push_back(
PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
NumberedOutputs.emplace_back(ArgNum++, output);
}
}

assert(
Expand Down Expand Up @@ -1053,15 +1060,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}

// Set names for input and output arguments.
if (NumScalarParams) {
ScalarAI = newFunction->arg_begin();
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
if (!StructValues.contains(inputs[i]))
ScalarAI->setName(inputs[i]->getName());
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
if (!StructValues.contains(outputs[i]))
ScalarAI->setName(outputs[i]->getName() + ".out");
}
for (auto [i, argVal] : NumberedInputs)
newFunction->getArg(i)->setName(argVal->getName());
for (auto [i, argVal] : NumberedOutputs)
newFunction->getArg(i)->setName(argVal->getName() + ".out");

// Rewrite branches to basic blocks outside of the loop to new dummy blocks
// within the new function. This must be done before we lose track of which
Expand Down
47 changes: 47 additions & 0 deletions llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,53 @@ TEST(CodeExtractor, PartialAggregateArgs) {
EXPECT_FALSE(verifyFunction(*Func));
}

/// Regression test to ensure we don't crash trying to set the name of the ptr
/// argument
TEST(CodeExtractor, PartialAggregateArgs2) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"ir(
declare void @usei(i32)
declare void @usep(ptr)

define void @foo(i32 %a, i32 %b, ptr %p) {
entry:
br label %extract

extract:
call void @usei(i32 %a)
call void @usei(i32 %b)
call void @usep(ptr %p)
br label %exit

exit:
ret void
}
)ir",
Err, Ctx));

Function *Func = M->getFunction("foo");
SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};

// Create the CodeExtractor with arguments aggregation enabled.
CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ true);
EXPECT_TRUE(CE.isEligible());

CodeExtractorAnalysisCache CEAC(*Func);
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
BasicBlock *CommonExit = nullptr;
CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
// Exclude the last input from the argument aggregate.
CE.excludeArgFromAggregate(Inputs[2]);

Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
EXPECT_TRUE(Outlined);
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}

TEST(CodeExtractor, OpenMPAggregateArgs) {
LLVMContext Ctx;
SMDiagnostic Err;
Expand Down
Loading