Skip to content

Commit 852ca95

Browse files
committed
[OpenMP][OMPIRBuilder] Use OMPKinds.def to specify callback metadata
1 parent 54da543 commit 852ca95

File tree

3 files changed

+95
-15
lines changed

3 files changed

+95
-15
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,3 +1418,26 @@ __OMP_ASSUME_CLAUSE(llvm::StringLiteral("no_parallelism"), false, false, false)
14181418
#undef __OMP_ASSUME_CLAUSE
14191419
#undef OMP_ASSUME_CLAUSE
14201420
///}
1421+
1422+
1423+
/// Callback specification
1424+
///
1425+
///{
1426+
1427+
#ifndef OMP_CALLBACK
1428+
#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...)
1429+
#endif
1430+
1431+
#define __OMP_CALLBACK(Name, VarArgsArePassed, CallbackArgNo, ...) \
1432+
OMP_CALLBACK(OMPRTL_##Name, VarArgsArePassed, CallbackArgNo, __VA_ARGS__)
1433+
1434+
__OMP_CALLBACK(__kmpc_fork_call, true, 2, -1, -1)
1435+
__OMP_CALLBACK(__kmpc_fork_call_if, true, 2, -1, -1)
1436+
__OMP_CALLBACK(__kmpc_fork_teams, true, 2, -1, -1)
1437+
__OMP_CALLBACK(__kmpc_omp_task_alloc, true, 5, -1, -1)
1438+
1439+
#undef __OMP_PTR_TYPE
1440+
1441+
#undef __OMP_TYPE
1442+
#undef OMP_CALLBACK
1443+
///}

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -614,21 +614,27 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
614614
#include "llvm/Frontend/OpenMP/OMPKinds.def"
615615
}
616616

617-
// Add information if the runtime function takes a callback function
618-
if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
619-
if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
620-
LLVMContext &Ctx = Fn->getContext();
621-
MDBuilder MDB(Ctx);
622-
// Annotate the callback behavior of the runtime function:
623-
// - The callback callee is argument number 2 (microtask).
624-
// - The first two arguments of the callback callee are unknown (-1).
625-
// - All variadic arguments to the runtime function are passed to the
626-
// callback callee.
627-
Fn->addMetadata(
628-
LLVMContext::MD_callback,
629-
*MDNode::get(Ctx, {MDB.createCallbackEncoding(
630-
2, {-1, -1}, /* VarArgsArePassed */ true)}));
631-
}
617+
// Annotate the callback behavior of the runtime function:
618+
// - First the callback callee argument number
619+
// - Then the arguments passed on to the callback (-1 for unknown),
620+
// variadic
621+
// - Finally, whether variadic args are passed on to the callback.
622+
LLVMContext &Ctx = Fn->getContext();
623+
MDBuilder MDB(Ctx);
624+
switch (FnID) {
625+
#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...) \
626+
case Enum: { \
627+
if (!Fn->hasMetadata(LLVMContext::MD_callback)) { \
628+
Fn->addMetadata(LLVMContext::MD_callback, \
629+
*MDNode::get(Ctx, {MDB.createCallbackEncoding( \
630+
CallbackArgNo, {__VA_ARGS__}, \
631+
VarArgsArePassed)})); \
632+
} \
633+
break; \
634+
}
635+
#include "llvm/Frontend/OpenMP/OMPKinds.def"
636+
default:
637+
break;
632638
}
633639

634640
LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7702,4 +7702,55 @@ TEST_F(OpenMPIRBuilderTest, splitBB) {
77027702
EXPECT_TRUE(DL == AllocaBB->getTerminator()->getStableDebugLoc());
77037703
}
77047704

7705+
TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
7706+
OpenMPIRBuilder OMPBuilder(*M);
7707+
OMPBuilder.initialize();
7708+
7709+
FunctionCallee ForkCall = OMPBuilder.getOrCreateRuntimeFunction(
7710+
*M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call);
7711+
FunctionCallee ForkCallIf = OMPBuilder.getOrCreateRuntimeFunction(
7712+
*M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call_if);
7713+
FunctionCallee ForkTeam = OMPBuilder.getOrCreateRuntimeFunction(
7714+
*M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_teams);
7715+
FunctionCallee TaskAlloc = OMPBuilder.getOrCreateRuntimeFunction(
7716+
*M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc);
7717+
7718+
M->dump();
7719+
for (auto [FC, ArgNo] : zip(SmallVector<FunctionCallee>(
7720+
{ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
7721+
SmallVector<unsigned>({2, 2, 2, 5}))) {
7722+
MDNode *CallbackMD =
7723+
cast<Function>(FC.getCallee())->getMetadata(LLVMContext::MD_callback);
7724+
EXPECT_NE(CallbackMD, nullptr);
7725+
unsigned Num = 0;
7726+
CallbackMD->dump();
7727+
M->dump();
7728+
for (const MDOperand &Op : CallbackMD->operands()) {
7729+
Num++;
7730+
MDNode *OpMD = cast<MDNode>(Op.get());
7731+
auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
7732+
uint64_t CBCalleeIdx =
7733+
cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
7734+
EXPECT_EQ(CBCalleeIdx, ArgNo);
7735+
7736+
uint64_t Arg0 =
7737+
cast<ConstantInt>(
7738+
cast<ConstantAsMetadata>(OpMD->getOperand(1))->getValue())
7739+
->getZExtValue();
7740+
uint64_t Arg1 =
7741+
cast<ConstantInt>(
7742+
cast<ConstantAsMetadata>(OpMD->getOperand(2))->getValue())
7743+
->getZExtValue();
7744+
uint64_t VarArg =
7745+
cast<ConstantInt>(
7746+
cast<ConstantAsMetadata>(OpMD->getOperand(3))->getValue())
7747+
->getZExtValue();
7748+
EXPECT_EQ(Arg0, -1);
7749+
EXPECT_EQ(Arg1, -1);
7750+
EXPECT_EQ(VarArg, true);
7751+
}
7752+
EXPECT_EQ(Num, 1);
7753+
}
7754+
}
7755+
77057756
} // namespace

0 commit comments

Comments
 (0)