@@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
53085308 Value *Alignment = AlignedItem.second ;
53095309 Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
53105310 Builder.SetInsertPoint (loadInst->getNextNode ());
5311- Builder.CreateAlignmentAssumption (F->getDataLayout (),
5312- AlignedPtr, Alignment);
5311+ Builder.CreateAlignmentAssumption (F->getDataLayout (), AlignedPtr,
5312+ Alignment);
53135313 }
53145314 Builder.restoreIP (IP);
53155315 }
@@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
54575457 Loop *L = LI.getLoopFor (CLI->getHeader ());
54585458 assert (L && " Expecting CanonicalLoopInfo to be recognized as a loop" );
54595459
5460- TargetTransformInfo::UnrollingPreferences UP =
5461- gatherUnrollingPreferences ( L, SE, TTI,
5462- /* BlockFrequencyInfo=*/ nullptr ,
5463- /* ProfileSummaryInfo=*/ nullptr , ORE, static_cast <int >(OptLevel),
5464- /* UserThreshold=*/ std::nullopt ,
5465- /* UserCount=*/ std::nullopt ,
5466- /* UserAllowPartial=*/ true ,
5467- /* UserAllowRuntime=*/ true ,
5468- /* UserUpperBound=*/ std::nullopt ,
5469- /* UserFullUnrollMaxCount=*/ std::nullopt );
5460+ TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences (
5461+ L, SE, TTI,
5462+ /* BlockFrequencyInfo=*/ nullptr ,
5463+ /* ProfileSummaryInfo=*/ nullptr , ORE, static_cast <int >(OptLevel),
5464+ /* UserThreshold=*/ std::nullopt ,
5465+ /* UserCount=*/ std::nullopt ,
5466+ /* UserAllowPartial=*/ true ,
5467+ /* UserAllowRuntime=*/ true ,
5468+ /* UserUpperBound=*/ std::nullopt ,
5469+ /* UserFullUnrollMaxCount=*/ std::nullopt );
54705470
54715471 UP.Force = true ;
54725472
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73407340 OpenMPIRBuilder::InsertPointTy AllocaIP,
73417341 const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
73427342 const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7343- Function *OutlinedFn, Constant *OutlinedFnID,
7343+ Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
73447344 SmallVectorImpl<Value *> &Args,
73457345 OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
73467346 SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73867386 return Error::success ();
73877387 };
73887388
7389- // If we don't have an ID for the target region, it means an offload entry
7390- // wasn't created. In this case we just run the host fallback directly.
7391- if (!OutlinedFnID) {
7389+ auto &&EmitTargetCallElse =
7390+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7391+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
73927392 // Assume no error was returned because EmitTargetCallFallbackCB doesn't
73937393 // produce any.
73947394 OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
@@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74047404 }());
74057405
74067406 Builder.restoreIP (AfterIP);
7407- return ;
7408- }
7409-
7410- OpenMPIRBuilder::TargetDataInfo Info (
7411- /* RequiresDevicePointerInfo=*/ false ,
7412- /* SeparateBeginEndCalls=*/ true );
7413-
7414- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7415- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7416- OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7417- RTArgs, MapInfo,
7418- /* IsNonContiguous=*/ true ,
7419- /* ForEndCall=*/ false );
7420-
7421- SmallVector<Value *, 3 > NumTeamsC;
7422- for (auto [DefaultVal, RuntimeVal] :
7423- zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7424- NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7425-
7426- // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7427- // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7428- auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7429- if (Clause)
7430- Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7431- /* isSigned=*/ false );
7432- return Clause;
7407+ return Error::success ();
74337408 };
7434- auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7435- if (Clause)
7436- Result = Result
7437- ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7409+
7410+ auto &&EmitTargetCallThen =
7411+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
7412+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
7413+ OpenMPIRBuilder::TargetDataInfo Info (
7414+ /* RequiresDevicePointerInfo=*/ false ,
7415+ /* SeparateBeginEndCalls=*/ true );
7416+
7417+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB (Builder.saveIP ());
7418+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
7419+ OMPBuilder.emitOffloadingArraysAndArgs (AllocaIP, Builder.saveIP (), Info,
7420+ RTArgs, MapInfo,
7421+ /* IsNonContiguous=*/ true ,
7422+ /* ForEndCall=*/ false );
7423+
7424+ SmallVector<Value *, 3 > NumTeamsC;
7425+ for (auto [DefaultVal, RuntimeVal] :
7426+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7427+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal
7428+ : Builder.getInt32 (DefaultVal));
7429+
7430+ // Calculate number of threads: 0 if no clauses specified, otherwise it is
7431+ // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7432+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7433+ if (Clause)
7434+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7435+ /* isSigned=*/ false );
7436+ return Clause;
7437+ };
7438+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7439+ if (Clause)
7440+ Result =
7441+ Result ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
74387442 Result, Clause)
74397443 : Clause;
7440- };
7444+ };
74417445
7442- // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7443- // the NUM_THREADS clause is overriden by THREAD_LIMIT.
7444- SmallVector<Value *, 3 > NumThreadsC;
7445- Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7446- ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7447- : nullptr ;
7446+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7447+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7448+ SmallVector<Value *, 3 > NumThreadsC;
7449+ Value *MaxThreadsClause =
7450+ RuntimeAttrs.TeamsThreadLimit .size () == 1
7451+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7452+ : nullptr ;
74487453
7449- for (auto [TeamsVal, TargetVal] : zip_equal (RuntimeAttrs. TeamsThreadLimit ,
7450- RuntimeAttrs.TargetThreadLimit )) {
7451- Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7452- Value *NumThreads = InitMaxThreadsClause (TargetVal);
7454+ for (auto [TeamsVal, TargetVal] : zip_equal (
7455+ RuntimeAttrs. TeamsThreadLimit , RuntimeAttrs.TargetThreadLimit )) {
7456+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7457+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
74537458
7454- CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7455- CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7459+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7460+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
74567461
7457- NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7458- }
7462+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7463+ }
74597464
7460- unsigned NumTargetItems = Info.NumberOfPtrs ;
7461- // TODO: Use correct device ID
7462- Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7463- uint32_t SrcLocStrSize;
7464- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7465- Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7466- llvm::omp::IdentFlag (0 ), 0 );
7465+ unsigned NumTargetItems = Info.NumberOfPtrs ;
7466+ // TODO: Use correct device ID
7467+ Value *DeviceID = Builder.getInt64 (OMP_DEVICEID_UNDEF);
7468+ uint32_t SrcLocStrSize;
7469+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7470+ Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7471+ llvm::omp::IdentFlag (0 ), 0 );
74677472
7468- Value *TripCount = RuntimeAttrs.LoopTripCount
7469- ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7470- Builder.getInt64Ty (),
7471- /* isSigned=*/ false )
7472- : Builder.getInt64 (0 );
7473+ Value *TripCount = RuntimeAttrs.LoopTripCount
7474+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7475+ Builder.getInt64Ty (),
7476+ /* isSigned=*/ false )
7477+ : Builder.getInt64 (0 );
74737478
7474- // TODO: Use correct DynCGGroupMem
7475- Value *DynCGGroupMem = Builder.getInt32 (0 );
7479+ // TODO: Use correct DynCGGroupMem
7480+ Value *DynCGGroupMem = Builder.getInt32 (0 );
74767481
7477- KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7478- NumTeamsC, NumThreadsC,
7479- DynCGGroupMem, HasNoWait);
7482+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7483+ NumTeamsC, NumThreadsC,
7484+ DynCGGroupMem, HasNoWait);
74807485
7481- // Assume no error was returned because TaskBodyCB and
7482- // EmitTargetCallFallbackCB don't produce any.
7483- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7484- // The presence of certain clauses on the target directive require the
7485- // explicit generation of the target task.
7486- if (RequiresOuterTargetTask)
7487- return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7488- Dependencies, HasNoWait);
7486+ // Assume no error was returned because TaskBodyCB and
7487+ // EmitTargetCallFallbackCB don't produce any.
7488+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail ([&]() {
7489+ // The presence of certain clauses on the target directive requires the
7490+ // explicit generation of the target task.
7491+ if (RequiresOuterTargetTask)
7492+ return OMPBuilder.emitTargetTask (TaskBodyCB, DeviceID, RTLoc, AllocaIP,
7493+ Dependencies, HasNoWait);
7494+
7495+ return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7496+ EmitTargetCallFallbackCB, KArgs,
7497+ DeviceID, RTLoc, AllocaIP);
7498+ }());
7499+
7500+ Builder.restoreIP (AfterIP);
7501+ return Error::success ();
7502+ };
74897503
7490- return OMPBuilder.emitKernelLaunch (Builder, OutlinedFnID,
7491- EmitTargetCallFallbackCB, KArgs,
7492- DeviceID, RTLoc, AllocaIP);
7493- }());
7504+ // If we don't have an ID for the target region, it means an offload entry
7505+ // wasn't created. In this case we just run the host fallback directly and
7506+ // ignore any potential 'if' clauses.
7507+ if (!OutlinedFnID) {
7508+ cantFail (EmitTargetCallElse (AllocaIP, Builder.saveIP ()));
7509+ return ;
7510+ }
7511+
7512+ // If there's no 'if' clause, only generate the kernel launch code path.
7513+ if (!IfCond) {
7514+ cantFail (EmitTargetCallThen (AllocaIP, Builder.saveIP ()));
7515+ return ;
7516+ }
74947517
7495- Builder.restoreIP (AfterIP);
7518+ cantFail (OMPBuilder.emitIfClause (IfCond, EmitTargetCallThen,
7519+ EmitTargetCallElse, AllocaIP));
74967520}
74977521
74987522OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
74997523 const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
75007524 InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
75017525 const TargetKernelDefaultAttrs &DefaultAttrs,
7502- const TargetKernelRuntimeAttrs &RuntimeAttrs,
7526+ const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
75037527 SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
75047528 OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
75057529 OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
75247548 // to make a remote call (offload) to the previously outlined function
75257549 // that represents the target region. Do that now.
75267550 if (!Config.isTargetDevice ())
7527- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7551+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
75287552 OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
75297553 HasNowait);
75307554 return Builder.saveIP ();
0 commit comments