@@ -1334,6 +1334,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
13341334 return Error::success ();
13351335}
13361336
1337+ // Create wrapper function used to gather the outlined function's argument
1338+ // structure from a shared buffer and to forward them to it when running in
1339+ // Generic mode.
1340+ //
1341+ // The outlined function is expected to receive 2 integer arguments followed by
1342+ // an optional pointer argument to an argument structure holding the rest.
1343+ static Function *createTargetParallelWrapper (OpenMPIRBuilder *OMPIRBuilder,
1344+ Function &OutlinedFn) {
1345+ size_t NumArgs = OutlinedFn.arg_size ();
1346+ assert ((NumArgs == 2 || NumArgs == 3 ) &&
1347+ " expected a 2-3 argument parallel outlined function" );
1348+ bool UseArgStruct = NumArgs == 3 ;
1349+
1350+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1351+ IRBuilder<>::InsertPointGuard IPG (Builder);
1352+ auto *FnTy = FunctionType::get (Builder.getVoidTy (),
1353+ {Builder.getInt16Ty (), Builder.getInt32Ty ()},
1354+ /* isVarArg=*/ false );
1355+ auto *WrapperFn =
1356+ Function::Create (FnTy, GlobalValue::InternalLinkage,
1357+ OutlinedFn.getName () + " .wrapper" , OMPIRBuilder->M );
1358+
1359+ WrapperFn->addParamAttr (0 , Attribute::NoUndef);
1360+ WrapperFn->addParamAttr (0 , Attribute::ZExt);
1361+ WrapperFn->addParamAttr (1 , Attribute::NoUndef);
1362+
1363+ BasicBlock *EntryBB =
1364+ BasicBlock::Create (OMPIRBuilder->M .getContext (), " entry" , WrapperFn);
1365+ Builder.SetInsertPoint (EntryBB);
1366+
1367+ // Allocation.
1368+ Value *AddrAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1369+ /* ArraySize=*/ nullptr , " addr" );
1370+ AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1371+ AddrAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1372+ AddrAlloca->getName () + " .ascast" );
1373+
1374+ Value *ZeroAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1375+ /* ArraySize=*/ nullptr , " zero" );
1376+ ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1377+ ZeroAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1378+ ZeroAlloca->getName () + " .ascast" );
1379+
1380+ Value *ArgsAlloca = nullptr ;
1381+ if (UseArgStruct) {
1382+ ArgsAlloca = Builder.CreateAlloca (Builder.getPtrTy (),
1383+ /* ArraySize=*/ nullptr , " global_args" );
1384+ ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1385+ ArgsAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1386+ ArgsAlloca->getName () + " .ascast" );
1387+ }
1388+
1389+ // Initialization.
1390+ Builder.CreateStore (WrapperFn->getArg (1 ), AddrAlloca);
1391+ Builder.CreateStore (Builder.getInt32 (0 ), ZeroAlloca);
1392+ if (UseArgStruct) {
1393+ Builder.CreateCall (
1394+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (
1395+ llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1396+ {ArgsAlloca});
1397+ }
1398+
1399+ SmallVector<Value *, 3 > Args{AddrAlloca, ZeroAlloca};
1400+
1401+ // Load structArg from global_args.
1402+ if (UseArgStruct) {
1403+ Value *StructArg = Builder.CreateLoad (Builder.getPtrTy (), ArgsAlloca);
1404+ StructArg = Builder.CreateInBoundsGEP (Builder.getPtrTy (), StructArg,
1405+ {Builder.getInt64 (0 )});
1406+ StructArg = Builder.CreateLoad (Builder.getPtrTy (), StructArg, " structArg" );
1407+ Args.push_back (StructArg);
1408+ }
1409+
1410+ // Call the outlined function holding the parallel body.
1411+ Builder.CreateCall (&OutlinedFn, Args);
1412+ Builder.CreateRetVoid ();
1413+
1414+ return WrapperFn;
1415+ }
1416+
13371417// Callback used to create OpenMP runtime calls to support
13381418// omp parallel clause for the device.
13391419// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1343,6 +1423,10 @@ static void targetParallelCallback(
13431423 BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
13441424 Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
13451425 Value *ThreadID, const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1426+ assert (OutlinedFn.arg_size () >= 2 &&
1427+ " Expected at least tid and bounded tid as arguments" );
1428+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1429+
13461430 // Add some known attributes.
13471431 IRBuilder<> &Builder = OMPIRBuilder->Builder ;
13481432 OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
@@ -1351,17 +1435,12 @@ static void targetParallelCallback(
13511435 OutlinedFn.addParamAttr (1 , Attribute::NoUndef);
13521436 OutlinedFn.addFnAttr (Attribute::NoUnwind);
13531437
1354- assert (OutlinedFn.arg_size () >= 2 &&
1355- " Expected at least tid and bounded tid as arguments" );
1356- unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1357-
13581438 CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
13591439 assert (CI && " Expected call instruction to outlined function" );
13601440 CI->getParent ()->setName (" omp_parallel" );
13611441
13621442 Builder.SetInsertPoint (CI);
13631443 Type *PtrTy = OMPIRBuilder->VoidPtr ;
1364- Value *NullPtrValue = Constant::getNullValue (PtrTy);
13651444
13661445 // Add alloca for kernel args
13671446 OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP ();
@@ -1387,6 +1466,15 @@ static void targetParallelCallback(
13871466 IfCondition ? Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 )
13881467 : Builder.getInt32 (1 );
13891468
1469+ // If this is not a Generic kernel, we can skip generating the wrapper.
1470+ std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1471+ getTargetKernelExecMode (*OuterFn);
1472+ Value *WrapperFn;
1473+ if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1474+ WrapperFn = createTargetParallelWrapper (OMPIRBuilder, OutlinedFn);
1475+ else
1476+ WrapperFn = Constant::getNullValue (PtrTy);
1477+
13901478 // Build kmpc_parallel_51 call
13911479 Value *Parallel51CallArgs[] = {
13921480 /* identifier*/ Ident,
@@ -1395,7 +1483,7 @@ static void targetParallelCallback(
13951483 /* number of threads */ NumThreads ? NumThreads : Builder.getInt32 (-1 ),
13961484 /* Proc bind */ Builder.getInt32 (-1 ),
13971485 /* outlined function */ &OutlinedFn,
1398- /* wrapper function */ NullPtrValue ,
1486+ /* wrapper function */ WrapperFn ,
13991487 /* arguments of the outlined funciton*/ Args,
14001488 /* number of arguments */ Builder.getInt64 (NumCapturedVars)};
14011489
0 commit comments