@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
830830 return GV;
831831}
832832
833+ void OpenMPIRBuilder::emitUsed (StringRef Name, ArrayRef<WeakTrackingVH> List) {
834+ if (List.empty ())
835+ return ;
836+
837+ // Convert List to what ConstantArray needs.
838+ SmallVector<Constant *, 8 > UsedArray;
839+ UsedArray.resize (List.size ());
840+ for (unsigned I = 0 , E = List.size (); I != E; ++I)
841+ UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast (
842+ cast<Constant>(&*List[I]), Builder.getPtrTy ());
843+
844+ if (UsedArray.empty ())
845+ return ;
846+ ArrayType *ATy = ArrayType::get (Builder.getPtrTy (), UsedArray.size ());
847+
848+ auto *GV = new GlobalVariable (M, ATy, false , GlobalValue::AppendingLinkage,
849+ ConstantArray::get (ATy, UsedArray), Name);
850+
851+ GV->setSection (" llvm.metadata" );
852+ }
853+
854+ GlobalVariable *
855+ OpenMPIRBuilder::emitKernelExecutionMode (StringRef KernelName,
856+ OMPTgtExecModeFlags Mode) {
857+ auto *Int8Ty = Builder.getInt8Ty ();
858+ auto *GVMode = new GlobalVariable (
859+ M, Int8Ty, /* isConstant=*/ true , GlobalValue::WeakAnyLinkage,
860+ ConstantInt::get (Int8Ty, Mode), Twine (KernelName, " _exec_mode" ));
861+ GVMode->setVisibility (GlobalVariable::ProtectedVisibility);
862+ return GVMode;
863+ }
864+
833865Constant *OpenMPIRBuilder::getOrCreateIdent (Constant *SrcLocStr,
834866 uint32_t SrcLocStrSize,
835867 IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
22602292 return OpenMPIRBuilder::InsertPointTy (I->getParent (), IT);
22612293}
22622294
2263- void OpenMPIRBuilder::emitUsed (StringRef Name,
2264- std::vector<WeakTrackingVH> &List) {
2265- if (List.empty ())
2266- return ;
2267-
2268- // Convert List to what ConstantArray needs.
2269- SmallVector<Constant *, 8 > UsedArray;
2270- UsedArray.resize (List.size ());
2271- for (unsigned I = 0 , E = List.size (); I != E; ++I)
2272- UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast (
2273- cast<Constant>(&*List[I]), Builder.getPtrTy ());
2274-
2275- if (UsedArray.empty ())
2276- return ;
2277- ArrayType *ATy = ArrayType::get (Builder.getPtrTy (), UsedArray.size ());
2278-
2279- auto *GV = new GlobalVariable (M, ATy, false , GlobalValue::AppendingLinkage,
2280- ConstantArray::get (ATy, UsedArray), Name);
2281-
2282- GV->setSection (" llvm.metadata" );
2283- }
2284-
22852295Value *OpenMPIRBuilder::getGPUThreadID () {
22862296 return Builder.CreateCall (
22872297 getOrCreateRuntimeFunction (M,
@@ -6131,10 +6141,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
61316141 uint32_t SrcLocStrSize;
61326142 Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
61336143 Constant *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6134- Constant *IsSPMDVal = ConstantInt::getSigned (
6135- Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6136- Constant *UseGenericStateMachineVal =
6137- ConstantInt::getSigned (Int8, !Attrs.IsSPMD );
6144+ Constant *IsSPMDVal = ConstantInt::getSigned (Int8, Attrs.ExecFlags );
6145+ Constant *UseGenericStateMachineVal = ConstantInt::getSigned (
6146+ Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
61386147 Constant *MayUseNestedParallelismVal = ConstantInt::getSigned (Int8, true );
61396148 Constant *DebugIndentionLevelVal = ConstantInt::getSigned (Int16, 0 );
61406149
@@ -6765,6 +6774,12 @@ static Expected<Function *> createOutlinedFunction(
67656774 auto Func =
67666775 Function::Create (FuncType, GlobalValue::InternalLinkage, FuncName, M);
67676776
6777+ if (OMPBuilder.Config .isTargetDevice ()) {
6778+ Value *ExecMode =
6779+ OMPBuilder.emitKernelExecutionMode (FuncName, DefaultAttrs.ExecFlags );
6780+ OMPBuilder.emitUsed (" llvm.compiler.used" , {ExecMode});
6781+ }
6782+
67686783 // Save insert point.
67696784 IRBuilder<>::InsertPointGuard IPG (Builder);
67706785 // If there's a DISubprogram associated with current function, then
@@ -7312,6 +7327,7 @@ static void
73127327emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73137328 OpenMPIRBuilder::InsertPointTy AllocaIP,
73147329 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7330+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
73157331 Function *OutlinedFn, Constant *OutlinedFnID,
73167332 SmallVectorImpl<Value *> &Args,
73177333 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7393,11 +7409,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73937409 /* ForEndCall=*/ false );
73947410
73957411 SmallVector<Value *, 3 > NumTeamsC;
7412+ for (auto [DefaultVal, RuntimeVal] :
7413+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7414+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7415+
7416+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7417+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7418+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7419+ if (Clause)
7420+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7421+ /* isSigned=*/ false );
7422+ return Clause;
7423+ };
7424+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7425+ if (Clause)
7426+ Result = Result
7427+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7428+ Result, Clause)
7429+ : Clause;
7430+ };
7431+
7432+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7433+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
73967434 SmallVector<Value *, 3 > NumThreadsC;
7397- for (auto V : DefaultAttrs.MaxTeams )
7398- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7399- for (auto V : DefaultAttrs.MaxThreads )
7400- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7435+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7436+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7437+ : nullptr ;
7438+
7439+ for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs.TeamsThreadLimit ,
7440+ RuntimeAttrs.TargetThreadLimit )) {
7441+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7442+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7443+
7444+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7445+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7446+
7447+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7448+ }
74017449
74027450 unsigned NumTargetItems = Info.NumberOfPtrs ;
74037451 // TODO: Use correct device ID
@@ -7406,14 +7454,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74067454 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
74077455 Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
74087456 llvm::omp::IdentFlag (0 ), 0 );
7409- // TODO: Use correct NumIterations
7410- Value *NumIterations = Builder.getInt64 (0 );
7457+
7458+ Value *TripCount = RuntimeAttrs.LoopTripCount
7459+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7460+ Builder.getInt64Ty (),
7461+ /* isSigned=*/ false )
7462+ : Builder.getInt64 (0 );
7463+
74117464 // TODO: Use correct DynCGGroupMem
74127465 Value *DynCGGroupMem = Builder.getInt32 (0 );
74137466
7414- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7415- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7416- DynCGGroupMem, HasNoWait);
7467+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7468+ NumTeamsC, NumThreadsC,
7469+ DynCGGroupMem, HasNoWait);
74177470
74187471 // The presence of certain clauses on the target directive require the
74197472 // explicit generation of the target task.
@@ -7438,6 +7491,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74387491 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74397492 InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
74407493 const TargetKernelDefaultAttrs &DefaultAttrs,
7494+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
74417495 SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74427496 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74437497 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7462,8 +7516,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74627516 // to make a remote call (offload) to the previously outlined function
74637517 // that represents the target region. Do that now.
74647518 if (!Config.isTargetDevice ())
7465- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7466- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7519+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7520+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7521+ HasNowait);
74677522 return Builder.saveIP ();
74687523}
74697524
0 commit comments