Skip to content

Commit 9533653

Browse files
committed
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels
This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.
1 parent 6bcb74a commit 9533653

File tree

2 files changed

+116
-9
lines changed

2 files changed

+116
-9
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,6 +1426,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
14261426
return Error::success();
14271427
}
14281428

1429+
// Create wrapper function used to gather the outlined function's argument
1430+
// structure from a shared buffer and to forward them to it when running in
1431+
// Generic mode.
1432+
//
1433+
// The outlined function is expected to receive 2 integer arguments followed by
1434+
// an optional pointer argument to an argument structure holding the rest.
1435+
static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
1436+
Function &OutlinedFn) {
1437+
size_t NumArgs = OutlinedFn.arg_size();
1438+
assert((NumArgs == 2 || NumArgs == 3) &&
1439+
"expected a 2-3 argument parallel outlined function");
1440+
bool UseArgStruct = NumArgs == 3;
1441+
1442+
IRBuilder<> &Builder = OMPIRBuilder->Builder;
1443+
IRBuilder<>::InsertPointGuard IPG(Builder);
1444+
auto *FnTy = FunctionType::get(Builder.getVoidTy(),
1445+
{Builder.getInt16Ty(), Builder.getInt32Ty()},
1446+
/*isVarArg=*/false);
1447+
auto *WrapperFn =
1448+
Function::Create(FnTy, GlobalValue::InternalLinkage,
1449+
OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
1450+
1451+
WrapperFn->addParamAttr(0, Attribute::NoUndef);
1452+
WrapperFn->addParamAttr(0, Attribute::ZExt);
1453+
WrapperFn->addParamAttr(1, Attribute::NoUndef);
1454+
1455+
BasicBlock *EntryBB =
1456+
BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
1457+
Builder.SetInsertPoint(EntryBB);
1458+
1459+
// Allocation.
1460+
Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
1461+
/*ArraySize=*/nullptr, "addr");
1462+
AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1463+
AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1464+
AddrAlloca->getName() + ".ascast");
1465+
1466+
Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
1467+
/*ArraySize=*/nullptr, "zero");
1468+
ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1469+
ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1470+
ZeroAlloca->getName() + ".ascast");
1471+
1472+
Value *ArgsAlloca = nullptr;
1473+
if (UseArgStruct) {
1474+
ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
1475+
/*ArraySize=*/nullptr, "global_args");
1476+
ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
1477+
ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
1478+
ArgsAlloca->getName() + ".ascast");
1479+
}
1480+
1481+
// Initialization.
1482+
Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
1483+
Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
1484+
if (UseArgStruct) {
1485+
Builder.CreateCall(
1486+
OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
1487+
llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1488+
{ArgsAlloca});
1489+
}
1490+
1491+
SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
1492+
1493+
// Load structArg from global_args.
1494+
if (UseArgStruct) {
1495+
Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
1496+
StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
1497+
{Builder.getInt64(0)});
1498+
StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
1499+
Args.push_back(StructArg);
1500+
}
1501+
1502+
// Call the outlined function holding the parallel body.
1503+
Builder.CreateCall(&OutlinedFn, Args);
1504+
Builder.CreateRetVoid();
1505+
1506+
return WrapperFn;
1507+
}
1508+
14291509
// Callback used to create OpenMP runtime calls to support
14301510
// omp parallel clause for the device.
14311511
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1435,6 +1515,10 @@ static void targetParallelCallback(
14351515
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
14361516
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
14371517
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
1518+
assert(OutlinedFn.arg_size() >= 2 &&
1519+
"Expected at least tid and bounded tid as arguments");
1520+
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1521+
14381522
// Add some known attributes.
14391523
IRBuilder<> &Builder = OMPIRBuilder->Builder;
14401524
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1443,17 +1527,12 @@ static void targetParallelCallback(
14431527
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
14441528
OutlinedFn.addFnAttr(Attribute::NoUnwind);
14451529

1446-
assert(OutlinedFn.arg_size() >= 2 &&
1447-
"Expected at least tid and bounded tid as arguments");
1448-
unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
1449-
14501530
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
14511531
assert(CI && "Expected call instruction to outlined function");
14521532
CI->getParent()->setName("omp_parallel");
14531533

14541534
Builder.SetInsertPoint(CI);
14551535
Type *PtrTy = OMPIRBuilder->VoidPtr;
1456-
Value *NullPtrValue = Constant::getNullValue(PtrTy);
14571536

14581537
// Add alloca for kernel args
14591538
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1479,6 +1558,15 @@ static void targetParallelCallback(
14791558
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
14801559
: Builder.getInt32(1);
14811560

1561+
// If this is not a Generic kernel, we can skip generating the wrapper.
1562+
std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1563+
getTargetKernelExecMode(*OuterFn);
1564+
Value *WrapperFn;
1565+
if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1566+
WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
1567+
else
1568+
WrapperFn = Constant::getNullValue(PtrTy);
1569+
14821570
// Build kmpc_parallel_51 call
14831571
Value *Parallel51CallArgs[] = {
14841572
/* identifier*/ Ident,
@@ -1487,7 +1575,7 @@ static void targetParallelCallback(
14871575
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
14881576
/* Proc bind */ Builder.getInt32(-1),
14891577
/* outlined function */ &OutlinedFn,
1490-
/* wrapper function */ NullPtrValue,
1578+
/* wrapper function */ WrapperFn,
14911579
/* arguments of the outlined funciton*/ Args,
14921580
/* number of arguments */ Builder.getInt64(NumCapturedVars)};
14931581

mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
6969
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
7070
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
7171
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
72-
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
72+
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
7373
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
7474
// CHECK: call void @__kmpc_target_deinit()
7575

@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
8484
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
8585
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
8686
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
87-
// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
87+
// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
8888

8989
// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause
9090
// of omp parallel construct for target region. If this argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
105105
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
106106
// CHECK-SAME: ptr addrspace(1) {{.*}} to ptr),
107107
// CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
108-
// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
108+
// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
109+
110+
// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
111+
// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
112+
// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
113+
// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
114+
// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
115+
// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
116+
// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
117+
// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
118+
// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
119+
// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
120+
// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
121+
// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
122+
// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
123+
// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
124+
125+
// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
126+
// CHECK-NOT: define
127+
// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})

0 commit comments

Comments
 (0)