@@ -6119,19 +6119,22 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
61196119 return Builder.CreateCall (Fn, Args);
61206120}
61216121
6122- OpenMPIRBuilder::InsertPointTy
6123- OpenMPIRBuilder::createTargetInit (const LocationDescription &Loc, bool IsSPMD,
6124- int32_t MinThreadsVal, int32_t MaxThreadsVal,
6125- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6122+ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit (
6123+ const LocationDescription &Loc,
6124+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6125+ assert (!Attrs.MaxThreads .empty () && !Attrs.MaxTeams .empty () &&
6126+ " expected num_threads and num_teams to be specified" );
6127+
61266128 if (!updateToLocation (Loc))
61276129 return Loc.IP ;
61286130
61296131 uint32_t SrcLocStrSize;
61306132 Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
61316133 Constant *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
61326134 Constant *IsSPMDVal = ConstantInt::getSigned (
6133- Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6134- Constant *UseGenericStateMachineVal = ConstantInt::getSigned (Int8, !IsSPMD);
6135+ Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6136+ Constant *UseGenericStateMachineVal =
6137+ ConstantInt::getSigned (Int8, !Attrs.IsSPMD );
61356138 Constant *MayUseNestedParallelismVal = ConstantInt::getSigned (Int8, true );
61366139 Constant *DebugIndentionLevelVal = ConstantInt::getSigned (Int16, 0 );
61376140
@@ -6149,21 +6152,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
61496152
61506153 // Manifest the launch configuration in the metadata matching the kernel
61516154 // environment.
6152- if (MinTeamsVal > 1 || MaxTeamsVal > 0 )
6153- writeTeamsForKernel (T, *Kernel, MinTeamsVal, MaxTeamsVal );
6155+ if (Attrs. MinTeams > 1 || Attrs. MaxTeams . front () > 0 )
6156+ writeTeamsForKernel (T, *Kernel, Attrs. MinTeams , Attrs. MaxTeams . front () );
61546157
6155- // For max values, < 0 means unset, == 0 means set but unknown.
6158+ // If MaxThreads not set, select the maximum between the default workgroup
6159+ // size and the MinThreads value.
6160+ int32_t MaxThreadsVal = Attrs.MaxThreads .front ();
61566161 if (MaxThreadsVal < 0 )
61576162 MaxThreadsVal = std::max (
6158- int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), MinThreadsVal );
6163+ int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), Attrs. MinThreads );
61596164
61606165 if (MaxThreadsVal > 0 )
6161- writeThreadBoundsForKernel (T, *Kernel, MinThreadsVal , MaxThreadsVal);
6166+ writeThreadBoundsForKernel (T, *Kernel, Attrs. MinThreads , MaxThreadsVal);
61626167
6163- Constant *MinThreads = ConstantInt::getSigned (Int32, MinThreadsVal );
6168+ Constant *MinThreads = ConstantInt::getSigned (Int32, Attrs. MinThreads );
61646169 Constant *MaxThreads = ConstantInt::getSigned (Int32, MaxThreadsVal);
6165- Constant *MinTeams = ConstantInt::getSigned (Int32, MinTeamsVal );
6166- Constant *MaxTeams = ConstantInt::getSigned (Int32, MaxTeamsVal );
6170+ Constant *MinTeams = ConstantInt::getSigned (Int32, Attrs. MinTeams );
6171+ Constant *MaxTeams = ConstantInt::getSigned (Int32, Attrs. MaxTeams . front () );
61676172 Constant *ReductionDataSize = ConstantInt::getSigned (Int32, 0 );
61686173 Constant *ReductionBufferLength = ConstantInt::getSigned (Int32, 0 );
61696174
@@ -6730,8 +6735,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67306735}
67316736
67326737static Expected<Function *> createOutlinedFunction (
6733- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6734- SmallVectorImpl<Value *> &Inputs,
6738+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6739+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6740+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67356741 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
67366742 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
67376743 SmallVector<Type *> ParameterTypes;
@@ -6798,7 +6804,7 @@ static Expected<Function *> createOutlinedFunction(
67986804
67996805 // Insert target init call in the device compilation pass.
68006806 if (OMPBuilder.Config .isTargetDevice ())
6801- Builder.restoreIP (OMPBuilder.createTargetInit (Builder, /* IsSPMD */ false ));
6807+ Builder.restoreIP (OMPBuilder.createTargetInit (Builder, DefaultAttrs ));
68026808
68036809 BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
68046810
@@ -6997,16 +7003,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
69977003
69987004static Error emitTargetOutlinedFunction (
69997005 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7000- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
7001- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
7006+ TargetRegionEntryInfo &EntryInfo,
7007+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7008+ Function *&OutlinedFn, Constant *&OutlinedFnID,
7009+ SmallVectorImpl<Value *> &Inputs,
70027010 OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
70037011 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
70047012
70057013 OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7006- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
7007- &ArgAccessorFuncCB](StringRef EntryFnName) {
7008- return createOutlinedFunction (OMPBuilder, Builder, EntryFnName, Inputs,
7009- CBFunc, ArgAccessorFuncCB);
7014+ [&](StringRef EntryFnName) {
7015+ return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7016+ EntryFnName, Inputs, CBFunc ,
7017+ ArgAccessorFuncCB);
70107018 };
70117019
70127020 return OMPBuilder.emitTargetRegionFunction (
@@ -7302,9 +7310,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
73027310
73037311static void
73047312emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7305- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7306- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7307- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7313+ OpenMPIRBuilder::InsertPointTy AllocaIP,
7314+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7315+ Function *OutlinedFn, Constant *OutlinedFnID,
7316+ SmallVectorImpl<Value *> &Args,
73087317 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
73097318 SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
73107319 bool HasNoWait = false ) {
@@ -7385,9 +7394,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73857394
73867395 SmallVector<Value *, 3 > NumTeamsC;
73877396 SmallVector<Value *, 3 > NumThreadsC;
7388- for (auto V : NumTeams )
7397+ for (auto V : DefaultAttrs. MaxTeams )
73897398 NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7390- for (auto V : NumThreads )
7399+ for (auto V : DefaultAttrs. MaxThreads )
73917400 NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
73927401
73937402 unsigned NumTargetItems = Info.NumberOfPtrs ;
@@ -7428,7 +7437,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74287437OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
74297438 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74307439 InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7431- ArrayRef< int32_t > NumTeams, ArrayRef< int32_t > NumThreads ,
7440+ const TargetKernelDefaultAttrs &DefaultAttrs ,
74327441 SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74337442 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74347443 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7445,16 +7454,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74457454 // the target region itself is generated using the callbacks CBFunc
74467455 // and ArgAccessorFuncCB
74477456 if (Error Err = emitTargetOutlinedFunction (
7448- *this , Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID ,
7449- Args, CBFunc, ArgAccessorFuncCB))
7457+ *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7458+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
74507459 return Err;
74517460
74527461 // If we are not on the target device, then we need to generate code
74537462 // to make a remote call (offload) to the previously outlined function
74547463 // that represents the target region. Do that now.
74557464 if (!Config.isTargetDevice ())
7456- emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams ,
7457- NumThreads , Args, GenMapInfoCB, Dependencies, HasNowait);
7465+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn ,
7466+ OutlinedFnID , Args, GenMapInfoCB, Dependencies, HasNowait);
74587467 return Builder.saveIP ();
74597468}
74607469
0 commit comments