@@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
13231323 return Error::success ();
13241324}
13251325
1326+ // Create the wrapper function that, when running in Generic mode, gathers the
1327+ // outlined function's argument structure from a shared buffer (obtained through
1328+ // __kmpc_get_shared_variables) and forwards it to the outlined function.
1329+ //
1330+ // The outlined function is expected to take 2 leading arguments (passed here as
1331+ // pointers to i32 slots) and an optional pointer to a structure holding the rest.
1332+ static Function *createTargetParallelWrapper (OpenMPIRBuilder *OMPIRBuilder,
1333+ Function &OutlinedFn) {
1334+ size_t NumArgs = OutlinedFn.arg_size ();
1335+ assert ((NumArgs == 2 || NumArgs == 3 ) &&
1336+ " expected a 2-3 argument parallel outlined function" );
1337+ bool UseArgStruct = NumArgs == 3 ; // A third argument means an argument structure must be forwarded.
1338+
1339+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1340+ IRBuilder<>::InsertPointGuard IPG (Builder); // Restores the builder's insert point on scope exit.
1341+ auto *FnTy = FunctionType::get (Builder.getVoidTy (), // Wrapper signature: void(i16, i32).
1342+ {Builder.getInt16Ty (), Builder.getInt32Ty ()},
1343+ /* isVarArg=*/ false );
1344+ auto *WrapperFn =
1345+ Function::Create (FnTy, GlobalValue::InternalLinkage,
1346+ OutlinedFn.getName () + " .wrapper" , OMPIRBuilder->M );
1347+
1348+ WrapperFn->addParamAttr (0 , Attribute::NoUndef); // Param 0 (i16): noundef + zeroext.
1349+ WrapperFn->addParamAttr (0 , Attribute::ZExt);
1350+ WrapperFn->addParamAttr (1 , Attribute::NoUndef); // Param 1 (i32): noundef.
1351+
1352+ BasicBlock *EntryBB =
1353+ BasicBlock::Create (OMPIRBuilder->M .getContext (), " entry" , WrapperFn);
1354+ Builder.SetInsertPoint (EntryBB); // Everything below is emitted into the wrapper's single block.
1355+
1356+ // Allocation: i32 stack slots whose addresses are passed as the outlined
// function's first two arguments; each address is cast to an address space 0 pointer.
1357+ Value *AddrAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1358+ /* ArraySize=*/ nullptr , " addr" );
1359+ AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1360+ AddrAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1361+ AddrAlloca->getName () + " .ascast" );
1362+
1363+ Value *ZeroAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1364+ /* ArraySize=*/ nullptr , " zero" );
1365+ ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1366+ ZeroAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1367+ ZeroAlloca->getName () + " .ascast" );
1368+
1369+ Value *ArgsAlloca = nullptr ; // Stays null when there is no argument structure to forward.
1370+ if (UseArgStruct) { // Pointer-sized slot handed to __kmpc_get_shared_variables below.
1371+ ArgsAlloca = Builder.CreateAlloca (Builder.getPtrTy (),
1372+ /* ArraySize=*/ nullptr , " global_args" );
1373+ ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1374+ ArgsAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1375+ ArgsAlloca->getName () + " .ascast" );
1376+ }
1377+
1378+ // Initialization: AddrAlloca takes the wrapper's second (i32) argument and
// ZeroAlloca is set to 0; the runtime call is given ArgsAlloca as its only argument.
1379+ Builder.CreateStore (WrapperFn->getArg (1 ), AddrAlloca);
1380+ Builder.CreateStore (Builder.getInt32 (0 ), ZeroAlloca);
1381+ if (UseArgStruct) {
1382+ Builder.CreateCall (
1383+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (
1384+ llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1385+ {ArgsAlloca});
1386+ }
1387+
1388+ SmallVector<Value *, 3 > Args{AddrAlloca, ZeroAlloca}; // Always-present leading arguments.
1389+
1390+ // Load structArg from global_args: read the pointer stored in ArgsAlloca, then
// load the pointer found at index 0 of the memory it points to.
1391+ if (UseArgStruct) {
1392+ Value *StructArg = Builder.CreateLoad (Builder.getPtrTy (), ArgsAlloca);
1393+ StructArg = Builder.CreateInBoundsGEP (Builder.getPtrTy (), StructArg,
1394+ {Builder.getInt64 (0 )});
1395+ StructArg = Builder.CreateLoad (Builder.getPtrTy (), StructArg, " structArg" );
1396+ Args.push_back (StructArg); // Becomes the outlined function's third argument.
1397+ }
1398+
1399+ // Call the outlined function holding the parallel body, then return void.
1400+ Builder.CreateCall (&OutlinedFn, Args);
1401+ Builder.CreateRetVoid ();
1402+
1403+ return WrapperFn; // The newly created internal-linkage wrapper.
1404+ }
1405+
13261406// Callback used to create OpenMP runtime calls to support
13271407// omp parallel clause for the device.
13281408// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1332,6 +1412,10 @@ static void targetParallelCallback(
13321412 BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
13331413 Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
13341414 Value *ThreadID, const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1415+ assert (OutlinedFn.arg_size () >= 2 &&
1416+ " Expected at least tid and bounded tid as arguments" );
1417+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1418+
13351419 // Add some known attributes.
13361420 IRBuilder<> &Builder = OMPIRBuilder->Builder ;
13371421 OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
@@ -1340,17 +1424,12 @@ static void targetParallelCallback(
13401424 OutlinedFn.addParamAttr (1 , Attribute::NoUndef);
13411425 OutlinedFn.addFnAttr (Attribute::NoUnwind);
13421426
1343- assert (OutlinedFn.arg_size () >= 2 &&
1344- " Expected at least tid and bounded tid as arguments" );
1345- unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1346-
13471427 CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
13481428 assert (CI && " Expected call instruction to outlined function" );
13491429 CI->getParent ()->setName (" omp_parallel" );
13501430
13511431 Builder.SetInsertPoint (CI);
13521432 Type *PtrTy = OMPIRBuilder->VoidPtr ;
1353- Value *NullPtrValue = Constant::getNullValue (PtrTy);
13541433
13551434 // Add alloca for kernel args
13561435 OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP ();
@@ -1376,6 +1455,15 @@ static void targetParallelCallback(
13761455 IfCondition ? Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 )
13771456 : Builder.getInt32 (1 );
13781457
1458+ // If this is not a Generic kernel, we can skip generating the wrapper.
1459+ std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1460+ getTargetKernelExecMode (*OuterFn);
1461+ Value *WrapperFn;
1462+ if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1463+ WrapperFn = createTargetParallelWrapper (OMPIRBuilder, OutlinedFn);
1464+ else
1465+ WrapperFn = Constant::getNullValue (PtrTy);
1466+
13791467 // Build kmpc_parallel_51 call
13801468 Value *Parallel51CallArgs[] = {
13811469 /* identifier*/ Ident,
@@ -1384,7 +1472,7 @@ static void targetParallelCallback(
13841472 /* number of threads */ NumThreads ? NumThreads : Builder.getInt32 (-1 ),
13851473 /* Proc bind */ Builder.getInt32 (-1 ),
13861474 /* outlined function */ &OutlinedFn,
1387- /* wrapper function */ NullPtrValue ,
1475+ /* wrapper function */ WrapperFn ,
13881476 /* arguments of the outlined funciton*/ Args,
13891477 /* number of arguments */ Builder.getInt64 (NumCapturedVars)};
13901478
0 commit comments