@@ -651,6 +651,13 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
651651
652652 Function *OuterFn = OI.getFunction ();
653653 CodeExtractorAnalysisCache CEAC (*OuterFn);
654+ // If we generate code for the target device, we need to allocate
655+ // struct for aggregate params in the device default alloca address space.
656+ // OpenMP runtime requires that the params of the extracted functions are
657+ // passed as zero address space pointers. This flag ensures that
658+ // CodeExtractor generates correct code for extracted functions
659+ // which are used by OpenMP runtime.
660+ bool ArgsInZeroAddressSpace = Config.isTargetDevice ();
654661 CodeExtractor Extractor (Blocks, /* DominatorTree */ nullptr ,
655662 /* AggregateArgs */ true ,
656663 /* BlockFrequencyInfo */ nullptr ,
@@ -659,7 +666,7 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
659666 /* AllowVarArgs */ true ,
660667 /* AllowAlloca */ true ,
661668 /* AllocaBlock*/ OI.OuterAllocaBB ,
662- /* Suffix */ " .omp_par" );
669+ /* Suffix */ " .omp_par" , ArgsInZeroAddressSpace );
663670
664671 LLVM_DEBUG (dbgs () << " Before outlining: " << *OuterFn << " \n " );
665672 LLVM_DEBUG (dbgs () << " Entry " << OI.EntryBB ->getName ()
@@ -1101,6 +1108,182 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
11011108 Builder.SetInsertPoint (NonCancellationBlock, NonCancellationBlock->begin ());
11021109}
11031110
1111+ // Callback used to create OpenMP runtime calls to support
1112+ // omp parallel clause for the device.
1113+ // We need to use this callback to replace call to the OutlinedFn in OuterFn
1114+ // by the call to the OpenMP DeviceRTL runtime function (kmpc_parallel_51)
1115+ static void targetParallelCallback (
1116+ OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
1117+ BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
1118+ Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1119+ Value *ThreadID, const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1120+ // Add some known attributes.
1121+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1122+ OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
1123+ OutlinedFn.addParamAttr (1 , Attribute::NoAlias);
1124+ OutlinedFn.addParamAttr (0 , Attribute::NoUndef);
1125+ OutlinedFn.addParamAttr (1 , Attribute::NoUndef);
1126+ OutlinedFn.addFnAttr (Attribute::NoUnwind);
1127+
1128+ assert (OutlinedFn.arg_size () >= 2 &&
1129+ " Expected at least tid and bounded tid as arguments" );
1130+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1131+
1132+ CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1133+ assert (CI && " Expected call instruction to outlined function" );
1134+ CI->getParent ()->setName (" omp_parallel" );
1135+
1136+ Builder.SetInsertPoint (CI);
1137+ Type *PtrTy = OMPIRBuilder->VoidPtr ;
1138+ Value *NullPtrValue = Constant::getNullValue (PtrTy);
1139+
1140+ // Add alloca for kernel args
1141+ OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP ();
1142+ Builder.SetInsertPoint (OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt ());
1143+ AllocaInst *ArgsAlloca =
1144+ Builder.CreateAlloca (ArrayType::get (PtrTy, NumCapturedVars));
1145+ Value *Args = ArgsAlloca;
1146+ // Add address space cast if array for storing arguments is not allocated
1147+ // in address space 0
1148+ if (ArgsAlloca->getAddressSpace ())
1149+ Args = Builder.CreatePointerCast (ArgsAlloca, PtrTy);
1150+ Builder.restoreIP (CurrentIP);
1151+
1152+ // Store captured vars which are used by kmpc_parallel_51
1153+ for (unsigned Idx = 0 ; Idx < NumCapturedVars; Idx++) {
1154+ Value *V = *(CI->arg_begin () + 2 + Idx);
1155+ Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64 (
1156+ ArrayType::get (PtrTy, NumCapturedVars), Args, 0 , Idx);
1157+ Builder.CreateStore (V, StoreAddress);
1158+ }
1159+
1160+ Value *Cond =
1161+ IfCondition ? Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 )
1162+ : Builder.getInt32 (1 );
1163+
1164+ // Build kmpc_parallel_51 call
1165+ Value *Parallel51CallArgs[] = {
1166+ /* identifier*/ Ident,
1167+ /* global thread num*/ ThreadID,
1168+ /* if expression */ Cond,
1169+ /* number of threads */ NumThreads ? NumThreads : Builder.getInt32 (-1 ),
1170+ /* Proc bind */ Builder.getInt32 (-1 ),
1171+ /* outlined function */
1172+ Builder.CreateBitCast (&OutlinedFn, OMPIRBuilder->ParallelTaskPtr ),
1173+ /* wrapper function */ NullPtrValue,
1174+ /* arguments of the outlined funciton*/ Args,
1175+ /* number of arguments */ Builder.getInt64 (NumCapturedVars)};
1176+
1177+ FunctionCallee RTLFn =
1178+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_parallel_51);
1179+
1180+ Builder.CreateCall (RTLFn, Parallel51CallArgs);
1181+
1182+ LLVM_DEBUG (dbgs () << " With kmpc_parallel_51 placed: "
1183+ << *Builder.GetInsertBlock ()->getParent () << " \n " );
1184+
1185+ // Initialize the local TID stack location with the argument value.
1186+ Builder.SetInsertPoint (PrivTID);
1187+ Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin ();
1188+ Builder.CreateStore (Builder.CreateLoad (OMPIRBuilder->Int32 , OutlinedAI),
1189+ PrivTIDAddr);
1190+
1191+ // Remove redundant call to the outlined function.
1192+ CI->eraseFromParent ();
1193+
1194+ for (Instruction *I : ToBeDeleted) {
1195+ I->eraseFromParent ();
1196+ }
1197+ }
1198+
1199+ // Callback used to create OpenMP runtime calls to support
1200+ // omp parallel clause for the host.
1201+ // We need to use this callback to replace call to the OutlinedFn in OuterFn
1202+ // by the call to the OpenMP host runtime function ( __kmpc_fork_call[_if])
1203+ static void
1204+ hostParallelCallback (OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
1205+ Function *OuterFn, Value *Ident, Value *IfCondition,
1206+ Instruction *PrivTID, AllocaInst *PrivTIDAddr,
1207+ const SmallVector<Instruction *, 4 > &ToBeDeleted) {
1208+ IRBuilder<> &Builder = OMPIRBuilder->Builder ;
1209+ FunctionCallee RTLFn;
1210+ if (IfCondition) {
1211+ RTLFn =
1212+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call_if);
1213+ } else {
1214+ RTLFn =
1215+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call);
1216+ }
1217+ if (auto *F = dyn_cast<Function>(RTLFn.getCallee ())) {
1218+ if (!F->hasMetadata (LLVMContext::MD_callback)) {
1219+ LLVMContext &Ctx = F->getContext ();
1220+ MDBuilder MDB (Ctx);
1221+ // Annotate the callback behavior of the __kmpc_fork_call:
1222+ // - The callback callee is argument number 2 (microtask).
1223+ // - The first two arguments of the callback callee are unknown (-1).
1224+ // - All variadic arguments to the __kmpc_fork_call are passed to the
1225+ // callback callee.
1226+ F->addMetadata (LLVMContext::MD_callback,
1227+ *MDNode::get (Ctx, {MDB.createCallbackEncoding (
1228+ 2 , {-1 , -1 },
1229+ /* VarArgsArePassed */ true )}));
1230+ }
1231+ }
1232+ // Add some known attributes.
1233+ OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
1234+ OutlinedFn.addParamAttr (1 , Attribute::NoAlias);
1235+ OutlinedFn.addFnAttr (Attribute::NoUnwind);
1236+
1237+ assert (OutlinedFn.arg_size () >= 2 &&
1238+ " Expected at least tid and bounded tid as arguments" );
1239+ unsigned NumCapturedVars = OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1240+
1241+ CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1242+ CI->getParent ()->setName (" omp_parallel" );
1243+ Builder.SetInsertPoint (CI);
1244+
1245+ // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1246+ Value *ForkCallArgs[] = {
1247+ Ident, Builder.getInt32 (NumCapturedVars),
1248+ Builder.CreateBitCast (&OutlinedFn, OMPIRBuilder->ParallelTaskPtr )};
1249+
1250+ SmallVector<Value *, 16 > RealArgs;
1251+ RealArgs.append (std::begin (ForkCallArgs), std::end (ForkCallArgs));
1252+ if (IfCondition) {
1253+ Value *Cond = Builder.CreateSExtOrTrunc (IfCondition, OMPIRBuilder->Int32 );
1254+ RealArgs.push_back (Cond);
1255+ }
1256+ RealArgs.append (CI->arg_begin () + /* tid & bound tid */ 2 , CI->arg_end ());
1257+
1258+ // __kmpc_fork_call_if always expects a void ptr as the last argument
1259+ // If there are no arguments, pass a null pointer.
1260+ auto PtrTy = OMPIRBuilder->VoidPtr ;
1261+ if (IfCondition && NumCapturedVars == 0 ) {
1262+ Value *NullPtrValue = Constant::getNullValue (PtrTy);
1263+ RealArgs.push_back (NullPtrValue);
1264+ }
1265+ if (IfCondition && RealArgs.back ()->getType () != PtrTy)
1266+ RealArgs.back () = Builder.CreateBitCast (RealArgs.back (), PtrTy);
1267+
1268+ Builder.CreateCall (RTLFn, RealArgs);
1269+
1270+ LLVM_DEBUG (dbgs () << " With fork_call placed: "
1271+ << *Builder.GetInsertBlock ()->getParent () << " \n " );
1272+
1273+ // Initialize the local TID stack location with the argument value.
1274+ Builder.SetInsertPoint (PrivTID);
1275+ Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin ();
1276+ Builder.CreateStore (Builder.CreateLoad (OMPIRBuilder->Int32 , OutlinedAI),
1277+ PrivTIDAddr);
1278+
1279+ // Remove redundant call to the outlined function.
1280+ CI->eraseFromParent ();
1281+
1282+ for (Instruction *I : ToBeDeleted) {
1283+ I->eraseFromParent ();
1284+ }
1285+ }
1286+
11041287IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel (
11051288 const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
11061289 BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
@@ -1115,6 +1298,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
11151298 Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
11161299 Value *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
11171300 Value *ThreadID = getOrCreateThreadID (Ident);
1301+ // If we generate code for the target device, we need to allocate
1302+ // struct for aggregate params in the device default alloca address space.
1303+ // OpenMP runtime requires that the params of the extracted functions are
1304+ // passed as zero address space pointers. This flag ensures that extracted
1305+ // function arguments are declared in zero address space
1306+ bool ArgsInZeroAddressSpace = Config.isTargetDevice ();
11181307
11191308 if (NumThreads) {
11201309 // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
@@ -1148,13 +1337,28 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
11481337 // Change the location to the outer alloca insertion point to create and
11491338 // initialize the allocas we pass into the parallel region.
11501339 Builder.restoreIP (OuterAllocaIP);
1151- AllocaInst *TIDAddr = Builder.CreateAlloca (Int32, nullptr , " tid.addr" );
1152- AllocaInst *ZeroAddr = Builder.CreateAlloca (Int32, nullptr , " zero.addr" );
1340+ AllocaInst *TIDAddrAlloca = Builder.CreateAlloca (Int32, nullptr , " tid.addr" );
1341+ AllocaInst *ZeroAddrAlloca =
1342+ Builder.CreateAlloca (Int32, nullptr , " zero.addr" );
1343+ Instruction *TIDAddr = TIDAddrAlloca;
1344+ Instruction *ZeroAddr = ZeroAddrAlloca;
1345+ if (ArgsInZeroAddressSpace && M.getDataLayout ().getAllocaAddrSpace () != 0 ) {
1346+ // Add additional casts to enforce pointers in zero address space
1347+ TIDAddr = new AddrSpaceCastInst (
1348+ TIDAddrAlloca, PointerType ::get (M.getContext (), 0 ), " tid.addr.ascast" );
1349+ TIDAddr->insertAfter (TIDAddrAlloca);
1350+ ToBeDeleted.push_back (TIDAddr);
1351+ ZeroAddr = new AddrSpaceCastInst (ZeroAddrAlloca,
1352+ PointerType ::get (M.getContext (), 0 ),
1353+ " zero.addr.ascast" );
1354+ ZeroAddr->insertAfter (ZeroAddrAlloca);
1355+ ToBeDeleted.push_back (ZeroAddr);
1356+ }
11531357
11541358 // We only need TIDAddr and ZeroAddr for modeling purposes to get the
11551359 // associated arguments in the outlined function, so we delete them later.
1156- ToBeDeleted.push_back (TIDAddr );
1157- ToBeDeleted.push_back (ZeroAddr );
1360+ ToBeDeleted.push_back (TIDAddrAlloca );
1361+ ToBeDeleted.push_back (ZeroAddrAlloca );
11581362
11591363 // Create an artificial insertion point that will also ensure the blocks we
11601364 // are about to split are not degenerated.
@@ -1222,87 +1426,24 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
12221426 BodyGenCB (InnerAllocaIP, CodeGenIP);
12231427
12241428 LLVM_DEBUG (dbgs () << " After body codegen: " << *OuterFn << " \n " );
1225- FunctionCallee RTLFn;
1226- if (IfCondition)
1227- RTLFn = getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call_if);
1228- else
1229- RTLFn = getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_call);
1230-
1231- if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee ())) {
1232- if (!F->hasMetadata (llvm::LLVMContext::MD_callback)) {
1233- llvm::LLVMContext &Ctx = F->getContext ();
1234- MDBuilder MDB (Ctx);
1235- // Annotate the callback behavior of the __kmpc_fork_call:
1236- // - The callback callee is argument number 2 (microtask).
1237- // - The first two arguments of the callback callee are unknown (-1).
1238- // - All variadic arguments to the __kmpc_fork_call are passed to the
1239- // callback callee.
1240- F->addMetadata (
1241- llvm::LLVMContext::MD_callback,
1242- *llvm::MDNode::get (
1243- Ctx, {MDB.createCallbackEncoding (2 , {-1 , -1 },
1244- /* VarArgsArePassed */ true )}));
1245- }
1246- }
12471429
12481430 OutlineInfo OI;
1249- OI.PostOutlineCB = [=](Function &OutlinedFn) {
1250- // Add some known attributes.
1251- OutlinedFn.addParamAttr (0 , Attribute::NoAlias);
1252- OutlinedFn.addParamAttr (1 , Attribute::NoAlias);
1253- OutlinedFn.addFnAttr (Attribute::NoUnwind);
1254- OutlinedFn.addFnAttr (Attribute::NoRecurse);
1255-
1256- assert (OutlinedFn.arg_size () >= 2 &&
1257- " Expected at least tid and bounded tid as arguments" );
1258- unsigned NumCapturedVars =
1259- OutlinedFn.arg_size () - /* tid & bounded tid */ 2 ;
1260-
1261- CallInst *CI = cast<CallInst>(OutlinedFn.user_back ());
1262- CI->getParent ()->setName (" omp_parallel" );
1263- Builder.SetInsertPoint (CI);
1264-
1265- // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
1266- Value *ForkCallArgs[] = {
1267- Ident, Builder.getInt32 (NumCapturedVars),
1268- Builder.CreateBitCast (&OutlinedFn, ParallelTaskPtr)};
1269-
1270- SmallVector<Value *, 16 > RealArgs;
1271- RealArgs.append (std::begin (ForkCallArgs), std::end (ForkCallArgs));
1272- if (IfCondition) {
1273- Value *Cond = Builder.CreateSExtOrTrunc (IfCondition,
1274- Type::getInt32Ty (M.getContext ()));
1275- RealArgs.push_back (Cond);
1276- }
1277- RealArgs.append (CI->arg_begin () + /* tid & bound tid */ 2 , CI->arg_end ());
1278-
1279- // __kmpc_fork_call_if always expects a void ptr as the last argument
1280- // If there are no arguments, pass a null pointer.
1281- auto PtrTy = Type::getInt8PtrTy (M.getContext ());
1282- if (IfCondition && NumCapturedVars == 0 ) {
1283- llvm::Value *Void = ConstantPointerNull::get (PtrTy);
1284- RealArgs.push_back (Void);
1285- }
1286- if (IfCondition && RealArgs.back ()->getType () != PtrTy)
1287- RealArgs.back () = Builder.CreateBitCast (RealArgs.back (), PtrTy);
1288-
1289- Builder.CreateCall (RTLFn, RealArgs);
1290-
1291- LLVM_DEBUG (dbgs () << " With fork_call placed: "
1292- << *Builder.GetInsertBlock ()->getParent () << " \n " );
1293-
1294- InsertPointTy ExitIP (PRegExitBB, PRegExitBB->end ());
1295-
1296- // Initialize the local TID stack location with the argument value.
1297- Builder.SetInsertPoint (PrivTID);
1298- Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin ();
1299- Builder.CreateStore (Builder.CreateLoad (Int32, OutlinedAI), PrivTIDAddr);
1300-
1301- CI->eraseFromParent ();
1302-
1303- for (Instruction *I : ToBeDeleted)
1304- I->eraseFromParent ();
1305- };
1431+ if (Config.isTargetDevice ()) {
1432+ // Generate OpenMP target specific runtime call
1433+ OI.PostOutlineCB = [=, ToBeDeletedVec =
1434+ std::move (ToBeDeleted)](Function &OutlinedFn) {
1435+ targetParallelCallback (this , OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
1436+ IfCondition, NumThreads, PrivTID, PrivTIDAddr,
1437+ ThreadID, ToBeDeletedVec);
1438+ };
1439+ } else {
1440+ // Generate OpenMP host runtime call
1441+ OI.PostOutlineCB = [=, ToBeDeletedVec =
1442+ std::move (ToBeDeleted)](Function &OutlinedFn) {
1443+ hostParallelCallback (this , OutlinedFn, OuterFn, Ident, IfCondition,
1444+ PrivTID, PrivTIDAddr, ToBeDeletedVec);
1445+ };
1446+ }
13061447
13071448 // Adjust the finalization stack, verify the adjustment, and call the
13081449 // finalize function a last time to finalize values between the pre-fini
@@ -1342,7 +1483,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
13421483 /* AllowVarArgs */ true ,
13431484 /* AllowAlloca */ true ,
13441485 /* AllocationBlock */ OuterAllocaBlock,
1345- /* Suffix */ " .omp_par" );
1486+ /* Suffix */ " .omp_par" , ArgsInZeroAddressSpace );
13461487
13471488 // Find inputs to, outputs from the code region.
13481489 BasicBlock *CommonExit = nullptr ;
0 commit comments