Skip to content

Commit 6322239

Browse files
committed
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels
This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.
1 parent 06a2570 commit 6322239

File tree

2 files changed

+116
-9
lines changed

2 files changed

+116
-9
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
13341334
return Error::success();
13351335
}
13361336

1337+
// Create wrapper function used to gather the outlined function's argument
1338+
// structure from a shared buffer and to forward them to it when running in
1339+
// Generic mode.
1340+
//
1341+
// The outlined function is expected to receive 2 integer arguments followed by
1342+
// an optional pointer argument to an argument structure holding the rest.
1343+
static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
1344+
Function &OutlinedFn) {
1345+
size_t NumArgs = OutlinedFn.arg_size();
1346+
assert((NumArgs == 2 || NumArgs == 3) &&
1347+
"expected a 2-3 argument parallel outlined function");
1348+
bool UseArgStruct = NumArgs == 3;
1349+
1350+
IRBuilder<> &Builder = OMPIRBuilder->Builder;
1351+
IRBuilder<>::InsertPointGuard IPG(Builder);
1352+
auto *FnTy = FunctionType::get(Builder.getVoidTy(),
1353+
{Builder.getInt16Ty(), Builder.getInt32Ty()},
1354+
/*isVarArg=*/false);
1355+
auto *WrapperFn =
1356+
Function::Create(FnTy, GlobalValue::InternalLinkage,
1357+
OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
1358+
1359+
WrapperFn->addParamAttr(0, Attribute::NoUndef);
1360+
WrapperFn->addParamAttr(0, Attribute::ZExt);
1361+
WrapperFn->addParamAttr(1, Attribute::NoUndef);
1362+
1363+
BasicBlock *EntryBB =
1364+
BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
1365+
Builder.SetInsertPoint(EntryBB);
1366+
1367+
// Allocation.
1368+
Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
1369+
/*ArraySize=*/nullptr, "addr");
1370+
AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1371+
AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1372+
AddrAlloca->getName() + ".ascast");
1373+
1374+
Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
1375+
/*ArraySize=*/nullptr, "zero");
1376+
ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1377+
ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1378+
ZeroAlloca->getName() + ".ascast");
1379+
1380+
Value *ArgsAlloca = nullptr;
1381+
if (UseArgStruct) {
1382+
ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
1383+
/*ArraySize=*/nullptr, "global_args");
1384+
ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1385+
ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1386+
ArgsAlloca->getName() + ".ascast");
1387+
}
1388+
1389+
// Initialization.
1390+
Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
1391+
Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
1392+
if (UseArgStruct) {
1393+
Builder.CreateCall(
1394+
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
1395+
llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1396+
{ArgsAlloca});
1397+
}
1398+
1399+
SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
1400+
1401+
// Load structArg from global_args.
1402+
if (UseArgStruct) {
1403+
Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
1404+
StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
1405+
{Builder.getInt64(0)});
1406+
StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
1407+
Args.push_back(StructArg);
1408+
}
1409+
1410+
// Call the outlined function holding the parallel body.
1411+
Builder.CreateCall(&OutlinedFn, Args);
1412+
Builder.CreateRetVoid();
1413+
1414+
return WrapperFn;
1415+
}
1416+
13371417
// Callback used to create OpenMP runtime calls to support
13381418
// omp parallel clause for the device.
13391419
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1343,6 +1423,10 @@ static void targetParallelCallback(
13431423
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
13441424
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
13451425
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1426+
assert(OutlinedFn.arg_size() >= 2 &&
1427+
"Expected at least tid and bounded tid as arguments");
1428+
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1429+
13461430
// Add some known attributes.
13471431
IRBuilder<> &Builder = OMPIRBuilder->Builder;
13481432
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1351,17 +1435,12 @@ static void targetParallelCallback(
13511435
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
13521436
OutlinedFn.addFnAttr(Attribute::NoUnwind);
13531437

1354-
assert(OutlinedFn.arg_size() >= 2 &&
1355-
"Expected at least tid and bounded tid as arguments");
1356-
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1357-
13581438
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
13591439
assert(CI && "Expected call instruction to outlined function");
13601440
CI->getParent()->setName("omp_parallel");
13611441

13621442
Builder.SetInsertPoint(CI);
13631443
Type *PtrTy = OMPIRBuilder->VoidPtr;
1364-
Value *NullPtrValue = Constant::getNullValue(PtrTy);
13651444

13661445
// Add alloca for kernel args
13671446
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1387,6 +1466,15 @@ static void targetParallelCallback(
13871466
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
13881467
: Builder.getInt32(1);
13891468

1469+
// If this is not a Generic kernel, we can skip generating the wrapper.
1470+
std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1471+
getTargetKernelExecMode(*OuterFn);
1472+
Value *WrapperFn;
1473+
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1474+
WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
1475+
else
1476+
WrapperFn = Constant::getNullValue(PtrTy);
1477+
13901478
// Build kmpc_parallel_51 call
13911479
Value *Parallel51CallArgs[] = {
13921480
/* identifier*/ Ident,
@@ -1395,7 +1483,7 @@ static void targetParallelCallback(
13951483
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
13961484
/* Proc bind */ Builder.getInt32(-1),
13971485
/* outlined function */ &OutlinedFn,
1398-
/* wrapper function */ NullPtrValue,
1486+
/* wrapper function */ WrapperFn,
13991487
/* arguments of the outlined funciton*/ Args,
14001488
/* number of arguments */ Builder.getInt64(NumCapturedVars)};
14011489

mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
6969
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
7070
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
7171
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
72-
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
72+
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
7373
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
7474
// CHECK: call void @__kmpc_target_deinit()
7575

@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
8484
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
8585
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
8686
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
87-
// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
87+
// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
8888

8989
// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause
9090
// of omp parallel construct for target region. If this argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
105105
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
106106
// CHECK-SAME: ptr addrspace(1) {{.*}} to ptr),
107107
// CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
108-
// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
108+
// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
109+
110+
// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
111+
// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
112+
// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
113+
// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
114+
// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
115+
// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
116+
// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
117+
// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
118+
// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
119+
// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
120+
// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
121+
// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
122+
// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
123+
// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
124+
125+
// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
126+
// CHECK-NOT: define
127+
// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})

0 commit comments

Comments
 (0)