@@ -6968,8 +6968,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
69686968 }
69696969
69706970 OI.ExitBB = Builder.saveIP ().getBlock ();
6971- OI.PostOutlineCB = [this , ToBeDeleted, Dependencies,
6972- HasNoWait ](Function &OutlinedFn) mutable {
6971+ OI.PostOutlineCB = [this , ToBeDeleted, Dependencies, HasNoWait,
6972+ DeviceID ](Function &OutlinedFn) mutable {
69736973 assert (OutlinedFn.getNumUses () == 1 &&
69746974 " there must be a single user for the outlined function" );
69756975
@@ -6989,9 +6989,15 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
69896989 getOrCreateSrcLocStr (LocationDescription (Builder), SrcLocStrSize);
69906990 Value *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
69916991
6992- // @__kmpc_omp_task_alloc
6992+ // @__kmpc_omp_task_alloc or @__kmpc_omp_target_task_alloc
6993+ //
6994+ // If `HasNoWait == true`, we call @__kmpc_omp_target_task_alloc to provide
6995+ // the DeviceID to the deferred task and also since
6996+ // @__kmpc_omp_target_task_alloc creates an untied/async task.
69936997 Function *TaskAllocFn =
6994- getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_alloc);
6998+ !HasNoWait ? getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_alloc)
6999+ : getOrCreateRuntimeFunctionPtr (
7000+ OMPRTL___kmpc_omp_target_task_alloc);
69957001
69967002 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
69977003 // call.
@@ -7032,10 +7038,18 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
70327038 // Emit the @__kmpc_omp_task_alloc runtime call
70337039 // The runtime call returns a pointer to an area where the task captured
70347040 // variables must be copied before the task is run (TaskData)
7035- CallInst *TaskData = Builder.CreateCall (
7036- TaskAllocFn, {/* loc_ref=*/ Ident, /* gtid=*/ ThreadID, /* flags=*/ Flags,
7037- /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
7038- /* task_func=*/ ProxyFn});
7041+ CallInst *TaskData = nullptr ;
7042+
7043+ SmallVector<llvm::Value *> TaskAllocArgs = {
7044+ /* loc_ref=*/ Ident, /* gtid=*/ ThreadID,
7045+ /* flags=*/ Flags,
7046+ /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
7047+ /* task_func=*/ ProxyFn};
7048+
7049+ if (HasNoWait)
7050+ TaskAllocArgs.push_back (DeviceID);
7051+
7052+ TaskData = Builder.CreateCall (TaskAllocFn, TaskAllocArgs);
70397053
70407054 if (HasShareds) {
70417055 Value *Shareds = StaleCI->getArgOperand (1 );
@@ -7118,13 +7132,14 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
71187132 emitOffloadingArraysArgument (Builder, RTArgs, Info, ForEndCall);
71197133}
71207134
7121- static void emitTargetCall (
7122- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7123- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7124- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7125- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7126- OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7127- SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
7135+ static void
7136+ emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7137+ OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7138+ Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7139+ ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7140+ OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7141+ SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7142+ bool HasNoWait = false ) {
71287143 // Generate a function call to the host fallback implementation of the target
71297144 // region. This is called by the host when no offload entry was generated for
71307145 // the target region and when the offloading call fails at runtime.
@@ -7135,7 +7150,6 @@ static void emitTargetCall(
71357150 return Builder.saveIP ();
71367151 };
71377152
7138- bool HasNoWait = false ;
71397153 bool HasDependencies = Dependencies.size () > 0 ;
71407154 bool RequiresOuterTargetTask = HasNoWait || HasDependencies;
71417155
@@ -7211,7 +7225,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
72117225 SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
72127226 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
72137227 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7214- SmallVector<DependData> Dependencies) {
7228+ SmallVector<DependData> Dependencies, bool HasNowait ) {
72157229
72167230 if (!updateToLocation (Loc))
72177231 return InsertPointTy ();
@@ -7232,7 +7246,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
72327246 // that represents the target region. Do that now.
72337247 if (!Config.isTargetDevice ())
72347248 emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7235- NumThreads, Args, GenMapInfoCB, Dependencies);
7249+ NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait );
72367250 return Builder.saveIP ();
72377251}
72387252
0 commit comments