@@ -1426,6 +1426,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
14261426 return Error::success ();
14271427}
14281428
1429+ // Create wrapper function used to gather the outlined function's argument
1430+ // structure from a shared buffer and to forward them to it when running in
1431+ // Generic mode.
1432+ //
1433+ // The outlined function is expected to receive 2 integer arguments followed by
1434+ // an optional pointer argument to an argument structure holding the rest.
1435+ static Function *createTargetParallelWrapper (OpenMPIRBuilder *OMPIRBuilder,
1436+ Function &OutlinedFn) {
1437+ size_t NumArgs = OutlinedFn.arg_size ();
1438+ assert ((NumArgs == 2 || NumArgs == 3 ) &&
1439+ " expected a 2-3 argument parallel outlined function" );
1440+ bool UseArgStruct = NumArgs == 3 ;
1441+
1442+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1443+ IRBuilder<>::InsertPointGuard IPG (Builder);
1444+ auto *FnTy = FunctionType::get (Builder.getVoidTy (),
1445+ {Builder.getInt16Ty (), Builder.getInt32Ty ()},
1446+ /* isVarArg=*/ false );
1447+ auto *WrapperFn =
1448+ Function::Create (FnTy, GlobalValue::InternalLinkage,
1449+ OutlinedFn.getName () + " .wrapper" , OMPIRBuilder->M );
1450+
1451+ WrapperFn->addParamAttr (0 , Attribute::NoUndef);
1452+ WrapperFn->addParamAttr (0 , Attribute::ZExt);
1453+ WrapperFn->addParamAttr (1 , Attribute::NoUndef);
1454+
1455+ BasicBlock *EntryBB =
1456+ BasicBlock::Create (OMPIRBuilder->M .getContext (), " entry" , WrapperFn);
1457+ Builder.SetInsertPoint (EntryBB);
1458+
1459+ // Allocation.
1460+ Value *AddrAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1461+ /* ArraySize=*/ nullptr , " addr" );
1462+ AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1463+ AddrAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1464+ AddrAlloca->getName () + " .ascast" );
1465+
1466+ Value *ZeroAlloca = Builder.CreateAlloca (Builder.getInt32Ty (),
1467+ /* ArraySize=*/ nullptr , " zero" );
1468+ ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1469+ ZeroAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1470+ ZeroAlloca->getName () + " .ascast" );
1471+
1472+ Value *ArgsAlloca = nullptr ;
1473+ if (UseArgStruct) {
1474+ ArgsAlloca = Builder.CreateAlloca (Builder.getPtrTy (),
1475+ /* ArraySize=*/ nullptr , " global_args" );
1476+ ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast (
1477+ ArgsAlloca, Builder.getPtrTy (/* AddrSpace=*/ 0 ),
1478+ ArgsAlloca->getName () + " .ascast" );
1479+ }
1480+
1481+ // Initialization.
1482+ Builder.CreateStore (WrapperFn->getArg (1 ), AddrAlloca);
1483+ Builder.CreateStore (Builder.getInt32 (0 ), ZeroAlloca);
1484+ if (UseArgStruct) {
1485+ Builder.CreateCall (
1486+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (
1487+ llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
1488+ {ArgsAlloca});
1489+ }
1490+
1491+ SmallVector<Value *, 3 > Args{AddrAlloca, ZeroAlloca};
1492+
1493+ // Load structArg from global_args.
1494+ if (UseArgStruct) {
1495+ Value *StructArg = Builder.CreateLoad (Builder.getPtrTy (), ArgsAlloca);
1496+ StructArg = Builder.CreateInBoundsGEP (Builder.getPtrTy (), StructArg,
1497+ {Builder.getInt64 (0 )});
1498+ StructArg = Builder.CreateLoad (Builder.getPtrTy (), StructArg, " structArg" );
1499+ Args.push_back (StructArg);
1500+ }
1501+
1502+ // Call the outlined function holding the parallel body.
1503+ Builder.CreateCall (&OutlinedFn, Args);
1504+ Builder.CreateRetVoid ();
1505+
1506+ return WrapperFn;
1507+ }
1508+
14291509// Callback used to create OpenMP runtime calls to support
14301510// omp parallel clause for the device.
14311511// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1435,6 +1515,10 @@ static void targetParallelCallback(
14351515 BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
14361516 Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
14371517 Value *ThreadID, const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1518+ assert (OutlinedFn.arg_size () >= 2 &&
1519+ " Expected at least tid and bounded tid as arguments" );
1520+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1521+
14381522 // Add some known attributes.
14391523 IRBuilder<> &Builder = OMPIRBuilder->Builder ;
14401524 OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
@@ -1443,17 +1527,12 @@ static void targetParallelCallback(
14431527 OutlinedFn.addParamAttr (1 , Attribute::NoUndef);
14441528 OutlinedFn.addFnAttr (Attribute::NoUnwind);
14451529
1446- assert (OutlinedFn.arg_size () >= 2 &&
1447- " Expected at least tid and bounded tid as arguments" );
1448- unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1449-
14501530 CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
14511531 assert (CI && " Expected call instruction to outlined function" );
14521532 CI->getParent ()->setName (" omp_parallel" );
14531533
14541534 Builder.SetInsertPoint (CI);
14551535 Type *PtrTy = OMPIRBuilder->VoidPtr ;
1456- Value *NullPtrValue = Constant::getNullValue (PtrTy);
14571536
14581537 // Add alloca for kernel args
14591538 OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP ();
@@ -1479,6 +1558,15 @@ static void targetParallelCallback(
14791558 IfCondition ? Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 )
14801559 : Builder.getInt32 (1 );
14811560
1561+ // If this is not a Generic kernel, we can skip generating the wrapper.
1562+ std::optional<omp::OMPTgtExecModeFlags> ExecMode =
1563+ getTargetKernelExecMode (*OuterFn);
1564+ Value *WrapperFn;
1565+ if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
1566+ WrapperFn = createTargetParallelWrapper (OMPIRBuilder, OutlinedFn);
1567+ else
1568+ WrapperFn = Constant::getNullValue (PtrTy);
1569+
14821570 // Build kmpc_parallel_51 call
14831571 Value *Parallel51CallArgs[] = {
14841572 /* identifier*/ Ident,
@@ -1487,7 +1575,7 @@ static void targetParallelCallback(
14871575 /* number of threads */ NumThreads ? NumThreads : Builder.getInt32 (-1 ),
14881576 /* Proc bind */ Builder.getInt32 (-1 ),
14891577 /* outlined function */ &OutlinedFn,
1490- /* wrapper function */ NullPtrValue ,
1578+ /* wrapper function */ WrapperFn ,
14911579 /* arguments of the outlined funciton*/ Args,
14921580 /* number of arguments */ Builder.getInt64 (NumCapturedVars)};
14931581
0 commit comments