Skip to content

Commit 8d17875

Browse files
authored
[OMPIRBuilder] Added createTeams (#66807)
This patch adds basic support for `omp teams` to the OpenMPIRBuilder. The outlined function after code extraction is called from a wrapper function with appropriate arguments. This wrapper function is passed to the runtime calls. This approach is different from the Clang approach - clang directly emits the runtime call to the outlined function. The outlining utility (OutlineInfo) simply outlines the code and generates a function call to the outlined function. After the function has been generated by the outlining utility, there is no easy way to alter the function arguments without meddling with the outlining itself. Hence the wrapper function approach is taken.
1 parent 541e88d commit 8d17875

File tree

3 files changed

+211
-0
lines changed

3 files changed

+211
-0
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,6 +1889,13 @@ class OpenMPIRBuilder {
18891889
BodyGenCallbackTy BodyGenCB,
18901890
FinalizeCallbackTy FiniCB);
18911891

1892+
/// Generator for `#omp teams`
1893+
///
1894+
/// \param Loc The location where the teams construct was encountered.
1895+
/// \param BodyGenCB Callback that will generate the region code.
1896+
InsertPointTy createTeams(const LocationDescription &Loc,
1897+
BodyGenCallbackTy BodyGenCB);
1898+
18921899
/// Generate conditional branch and relevant BasicBlocks through which private
18931900
/// threads copy the 'copyin' variables from Master copy to threadprivate
18941901
/// copies.

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5735,6 +5735,129 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
57355735
return Builder.saveIP();
57365736
}
57375737

5738+
OpenMPIRBuilder::InsertPointTy
5739+
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
5740+
BodyGenCallbackTy BodyGenCB) {
5741+
if (!updateToLocation(Loc))
5742+
return InsertPointTy();
5743+
5744+
uint32_t SrcLocStrSize;
5745+
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
5746+
Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
5747+
Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
5748+
5749+
// Outer allocation basicblock is the entry block of the current function.
5750+
BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
5751+
if (&OuterAllocaBB == Builder.GetInsertBlock()) {
5752+
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
5753+
Builder.SetInsertPoint(BodyBB, BodyBB->begin());
5754+
}
5755+
5756+
// The current basic block is split into four basic blocks. After outlining,
5757+
// they will be mapped as follows:
5758+
// ```
5759+
// def current_fn() {
5760+
// current_basic_block:
5761+
// br label %teams.exit
5762+
// teams.exit:
5763+
// ; instructions after teams
5764+
// }
5765+
//
5766+
// def outlined_fn() {
5767+
// teams.alloca:
5768+
// br label %teams.body
5769+
// teams.body:
5770+
// ; instructions within teams body
5771+
// }
5772+
// ```
5773+
BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
5774+
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
5775+
BasicBlock *AllocaBB =
5776+
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
5777+
5778+
OutlineInfo OI;
5779+
OI.EntryBB = AllocaBB;
5780+
OI.ExitBB = ExitBB;
5781+
OI.OuterAllocaBB = &OuterAllocaBB;
5782+
OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
5783+
// The input IR here looks like the following-
5784+
// ```
5785+
// func @current_fn() {
5786+
// outlined_fn(%args)
5787+
// }
5788+
// func @outlined_fn(%args) { ... }
5789+
// ```
5790+
//
5791+
// This is changed to the following-
5792+
//
5793+
// ```
5794+
// func @current_fn() {
5795+
// runtime_call(..., wrapper_fn, ...)
5796+
// }
5797+
// func @wrapper_fn(..., %args) {
5798+
// outlined_fn(%args)
5799+
// }
5800+
// func @outlined_fn(%args) { ... }
5801+
// ```
5802+
5803+
// The stale call instruction will be replaced with a new call instruction
5804+
// for runtime call with a wrapper function.
5805+
5806+
assert(OutlinedFn.getNumUses() == 1 &&
5807+
"there must be a single user for the outlined function");
5808+
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
5809+
5810+
// Create the wrapper function.
5811+
SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
5812+
for (auto &Arg : OutlinedFn.args())
5813+
WrapperArgTys.push_back(Arg.getType());
5814+
FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
5815+
(Twine(OutlinedFn.getName()) + ".teams").str(),
5816+
FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
5817+
Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
5818+
WrapperFunc->getArg(0)->setName("global_tid");
5819+
WrapperFunc->getArg(1)->setName("bound_tid");
5820+
if (WrapperFunc->arg_size() > 2)
5821+
WrapperFunc->getArg(2)->setName("data");
5822+
5823+
// Emit the body of the wrapper function - just a call to outlined function
5824+
// and return statement.
5825+
BasicBlock *WrapperEntryBB =
5826+
BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
5827+
Builder.SetInsertPoint(WrapperEntryBB);
5828+
SmallVector<Value *> Args;
5829+
for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
5830+
Args.push_back(WrapperFunc->getArg(ArgIndex));
5831+
Builder.CreateCall(&OutlinedFn, Args);
5832+
Builder.CreateRetVoid();
5833+
5834+
OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
5835+
5836+
// Call to the runtime function for teams in the current function.
5837+
assert(StaleCI && "Error while outlining - no CallInst user found for the "
5838+
"outlined function.");
5839+
Builder.SetInsertPoint(StaleCI);
5840+
Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
5841+
for (Use &Arg : StaleCI->args())
5842+
Args.push_back(Arg);
5843+
Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
5844+
omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
5845+
Args);
5846+
StaleCI->eraseFromParent();
5847+
};
5848+
5849+
addOutlineInfo(std::move(OI));
5850+
5851+
// Generate the body of teams.
5852+
InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
5853+
InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
5854+
BodyGenCB(AllocaIP, CodeGenIP);
5855+
5856+
Builder.SetInsertPoint(ExitBB, ExitBB->begin());
5857+
5858+
return Builder.saveIP();
5859+
}
5860+
57385861
GlobalVariable *
57395862
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
57405863
std::string VarName) {

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4001,6 +4001,87 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareCapture) {
40014001
EXPECT_FALSE(verifyModule(*M, &errs()));
40024002
}
40034003

4004+
TEST_F(OpenMPIRBuilderTest, CreateTeams) {
4005+
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4006+
OpenMPIRBuilder OMPBuilder(*M);
4007+
OMPBuilder.initialize();
4008+
F->setName("func");
4009+
IRBuilder<> Builder(BB);
4010+
4011+
AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
4012+
AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
4013+
Value *Val128 = Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "load");
4014+
4015+
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4016+
Builder.restoreIP(AllocaIP);
4017+
AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
4018+
"bodygen.alloca128");
4019+
4020+
Builder.restoreIP(CodeGenIP);
4021+
// Loading and storing captured pointer and values
4022+
Builder.CreateStore(Val128, Local128);
4023+
Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
4024+
"bodygen.load32");
4025+
4026+
LoadInst *PrivLoad128 = Builder.CreateLoad(
4027+
Local128->getAllocatedType(), Local128, "bodygen.local.load128");
4028+
Value *Cmp = Builder.CreateICmpNE(
4029+
Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
4030+
Instruction *ThenTerm, *ElseTerm;
4031+
SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
4032+
&ThenTerm, &ElseTerm);
4033+
};
4034+
4035+
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
4036+
Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
4037+
4038+
OMPBuilder.finalize();
4039+
Builder.CreateRetVoid();
4040+
4041+
EXPECT_FALSE(verifyModule(*M, &errs()));
4042+
4043+
CallInst *TeamsForkCall = dyn_cast<CallInst>(
4044+
OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)
4045+
->user_back());
4046+
4047+
// Verify the Ident argument
4048+
GlobalVariable *Ident = cast<GlobalVariable>(TeamsForkCall->getArgOperand(0));
4049+
ASSERT_NE(Ident, nullptr);
4050+
EXPECT_TRUE(Ident->hasInitializer());
4051+
Constant *Initializer = Ident->getInitializer();
4052+
GlobalVariable *SrcStrGlob =
4053+
cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
4054+
ASSERT_NE(SrcStrGlob, nullptr);
4055+
ConstantDataArray *SrcSrc =
4056+
dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
4057+
ASSERT_NE(SrcSrc, nullptr);
4058+
4059+
// Verify the outlined function signature.
4060+
Function *WrapperFn =
4061+
dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
4062+
ASSERT_NE(WrapperFn, nullptr);
4063+
EXPECT_FALSE(WrapperFn->isDeclaration());
4064+
EXPECT_TRUE(WrapperFn->arg_size() >= 3);
4065+
EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
4066+
EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
4067+
EXPECT_EQ(WrapperFn->getArg(2)->getType(),
4068+
Builder.getPtrTy()); // captured args
4069+
4070+
// Check for TruncInst and ICmpInst in the outlined function.
4071+
inst_range Instructions = instructions(WrapperFn);
4072+
auto OutlinedFnInst = find_if(
4073+
Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); });
4074+
ASSERT_NE(OutlinedFnInst, Instructions.end());
4075+
CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst);
4076+
ASSERT_NE(OutlinedFnCI, nullptr);
4077+
Function *OutlinedFn = OutlinedFnCI->getCalledFunction();
4078+
4079+
EXPECT_TRUE(any_of(instructions(OutlinedFn),
4080+
[](Instruction &inst) { return isa<TruncInst>(&inst); }));
4081+
EXPECT_TRUE(any_of(instructions(OutlinedFn),
4082+
[](Instruction &inst) { return isa<ICmpInst>(&inst); }));
4083+
}
4084+
40044085
/// Returns the single instruction of InstTy type in BB that uses the value V.
40054086
/// If there is more than one such instruction, returns null.
40064087
template <typename InstTy>

0 commit comments

Comments
 (0)