@@ -501,15 +501,20 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
501501 Value *ZeroArray = Constant::getNullValue (ArrayType::get (Int32Ty, MaxDim));
502502 Value *Flags = Builder.getInt64 (KernelArgs.HasNoWait );
503503
504- assert (!KernelArgs.NumTeams .empty ());
504+ assert (!KernelArgs.NumTeams .empty () && !KernelArgs. NumThreads . empty () );
505505
506506 Value *NumTeams3D =
507507 Builder.CreateInsertValue (ZeroArray, KernelArgs.NumTeams [0 ], {0 });
508- for (unsigned I = 1 ; I < std::min (KernelArgs.NumTeams .size (), MaxDim); ++I)
508+ Value *NumThreads3D =
509+ Builder.CreateInsertValue (ZeroArray, KernelArgs.NumThreads [0 ], {0 });
510+ for (unsigned I :
511+ seq<unsigned >(1 , std::min (KernelArgs.NumTeams .size (), MaxDim)))
509512 NumTeams3D =
510513 Builder.CreateInsertValue (NumTeams3D, KernelArgs.NumTeams [I], {I});
511- Value *NumThreads3D =
512- Builder.CreateInsertValue (ZeroArray, KernelArgs.NumThreads , {0 });
514+ for (unsigned I :
515+ seq<unsigned >(1 , std::min (KernelArgs.NumThreads .size (), MaxDim)))
516+ NumThreads3D =
517+ Builder.CreateInsertValue (NumThreads3D, KernelArgs.NumThreads [I], {I});
513518
514519 ArgsVector = {Version,
515520 PointerNum,
@@ -1114,9 +1119,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
11141119 // __tgt_target_teams() launches a GPU kernel with the requested number
11151120 // of teams and threads so no additional calls to the runtime are required.
11161121 // Check the error code and execute the host version if required.
1117- Builder.restoreIP (emitTargetKernel (Builder, AllocaIP, Return, RTLoc, DeviceID,
1118- Args.NumTeams .front (), Args. NumThreads ,
1119- OutlinedFnID, ArgsVector));
1122+ Builder.restoreIP (emitTargetKernel (
1123+ Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams .front (),
1124+ Args. NumThreads . front (), OutlinedFnID, ArgsVector));
11201125
11211126 BasicBlock *OffloadFailedBlock =
11221127 BasicBlock::Create (Builder.getContext (), " omp_offload.failed" );
@@ -7075,8 +7080,8 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
70757080static void emitTargetCall (
70767081 OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
70777082 OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7078- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams, int32_t NumThreads,
7079- SmallVectorImpl<Value *> &Args,
7083+ Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7084+ ArrayRef< int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
70807085 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
70817086 SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
70827087 // Generate a function call to the host fallback implementation of the target
@@ -7123,13 +7128,15 @@ static void emitTargetCall(
71237128 /* ForEndCall=*/ false );
71247129
71257130 SmallVector<Value *, 3 > NumTeamsC;
7131+ SmallVector<Value *, 3 > NumThreadsC;
71267132 for (auto V : NumTeams)
71277133 NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7134+ for (auto V : NumThreads)
7135+ NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
71287136
71297137 unsigned NumTargetItems = Info.NumberOfPtrs ;
71307138 // TODO: Use correct device ID
71317139 Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7132- Value *NumThreadsVal = Builder.getInt32 (NumThreads);
71337140 uint32_t SrcLocStrSize;
71347141 Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
71357142 Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
@@ -7140,8 +7147,8 @@ static void emitTargetCall(
71407147 Value *DynCGGroupMem = Builder.getInt32 (0 );
71417148
71427149 OpenMPIRBuilder::TargetKernelArgs KArgs (NumTargetItems, RTArgs, NumIterations,
7143- NumTeamsC, NumThreadsVal ,
7144- DynCGGroupMem, HasNoWait);
7150+ NumTeamsC, NumThreadsC, DynCGGroupMem ,
7151+ HasNoWait);
71457152
71467153 // The presence of certain clauses on the target directive require the
71477154 // explicit generation of the target task.
@@ -7159,11 +7166,11 @@ static void emitTargetCall(
71597166OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget (
71607167 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
71617168 InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7162- ArrayRef<int32_t > NumTeams, int32_t NumThreads,
7169+ ArrayRef<int32_t > NumTeams, ArrayRef< int32_t > NumThreads,
71637170 SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
71647171 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
71657172 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7166- SmallVector<DependData> Dependenciess ) {
7173+ SmallVector<DependData> Dependencies ) {
71677174
71687175 if (!updateToLocation (Loc))
71697176 return InsertPointTy ();
@@ -7184,7 +7191,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
71847191 // that represents the target region. Do that now.
71857192 if (!Config.isTargetDevice ())
71867193 emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7187- NumThreads, Args, GenMapInfoCB, Dependenciess );
7194+ NumThreads, Args, GenMapInfoCB, Dependencies );
71887195 return Builder.saveIP ();
71897196}
71907197
0 commit comments