@@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
34953495OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU (
34963496 const LocationDescription &Loc, InsertPointTy AllocaIP,
34973497 InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3498- bool IsNoWait, bool IsTeamsReduction, bool HasDistribute ,
3499- ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3500- unsigned ReductionBufNum, Value *SrcLocInfo) {
3498+ bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind ,
3499+ std::optional<omp::GV> GridValue, unsigned ReductionBufNum ,
3500+ Value *SrcLocInfo) {
35013501 if (!updateToLocation (Loc))
35023502 return InsertPointTy ();
35033503 Builder.restoreIP (CodeGenIP);
@@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
35143514 if (ReductionInfos.size () == 0 )
35153515 return Builder.saveIP ();
35163516
3517+ BasicBlock *ContinuationBlock = nullptr ;
3518+ if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
3519+ // Copied code from createReductions
3520+ BasicBlock *InsertBlock = Loc.IP .getBlock ();
3521+ ContinuationBlock =
3522+ InsertBlock->splitBasicBlock (Loc.IP .getPoint (), " reduce.finalize" );
3523+ InsertBlock->getTerminator ()->eraseFromParent ();
3524+ Builder.SetInsertPoint (InsertBlock, InsertBlock->end ());
3525+ }
3526+
35173527 Function *CurFunc = Builder.GetInsertBlock ()->getParent ();
35183528 AttributeList FuncAttrs;
35193529 AttrBuilder AttrBldr (Ctx);
@@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
36693679 ReductionFunc;
36703680 });
36713681 } else {
3672- assert (false && " Unhandled ReductionGenCBKind" );
3682+ Value *LHSValue = Builder.CreateLoad (RI.ElementType , LHS, " final.lhs" );
3683+ Value *RHSValue = Builder.CreateLoad (RI.ElementType , RHS, " final.rhs" );
3684+ Value *Reduced;
3685+ InsertPointOrErrorTy AfterIP =
3686+ RI.ReductionGen (Builder.saveIP (), RHSValue, LHSValue, Reduced);
3687+ if (!AfterIP)
3688+ return AfterIP.takeError ();
3689+ Builder.CreateStore (Reduced, LHS, false );
36733690 }
36743691 }
36753692 emitBlock (ExitBB, CurFunc);
3676-
3693+ if (ContinuationBlock) {
3694+ Builder.CreateBr (ContinuationBlock);
3695+ Builder.SetInsertPoint (ContinuationBlock);
3696+ }
36773697 Config.setEmitLLVMUsed ();
36783698
36793699 return Builder.saveIP ();
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
36883708 " .omp.reduction.func" , &M);
36893709}
36903710
3691- OpenMPIRBuilder::InsertPointOrErrorTy
3692- OpenMPIRBuilder::createReductions (const LocationDescription &Loc,
3693- InsertPointTy AllocaIP,
3694- ArrayRef<ReductionInfo> ReductionInfos,
3695- ArrayRef<bool > IsByRef, bool IsNoWait) {
3696- assert (ReductionInfos.size () == IsByRef.size ());
3697- for (const ReductionInfo &RI : ReductionInfos) {
3698- (void )RI;
3699- assert (RI.Variable && " expected non-null variable" );
3700- assert (RI.PrivateVariable && " expected non-null private variable" );
3701- assert (RI.ReductionGen && " expected non-null reduction generator callback" );
3702- assert (RI.Variable ->getType () == RI.PrivateVariable ->getType () &&
3703- " expected variables and their private equivalents to have the same "
3704- " type" );
3705- assert (RI.Variable ->getType ()->isPointerTy () &&
3706- " expected variables to be pointers" );
// Fills in the body of ReductionFunc: for every entry of ReductionInfos it
// loads one pointer from each of the two pointer arrays passed as arguments
// 0 and 1, loads the pointed-to elements, combines them through the entry's
// ReductionGen callback and (for non-by-ref entries) stores the result back
// through the pointer taken from argument 0.
//
// ReductionFunc   - function that receives the new entry block; its first two
//                   arguments are used as the LHS and RHS pointer arrays.
// ReductionInfos  - one descriptor per reduction; Variable, PrivateVariable,
//                   ElementType and ReductionGen are read here.
// Builder         - IR builder used for all emitted instructions; its insert
//                   point is changed by this function.
// IsByRef         - indexed in parallel with ReductionInfos; when true the
//                   final store for that entry is not emitted here.
// IsGPU           - when true the two arguments are first spilled to allocas
//                   and reloaded before being indexed.
//
// Returns the error produced by a ReductionGen callback, otherwise success.
3711+ static Error populateReductionFunction (
3712+ Function *ReductionFunc,
3713+ ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3714+ IRBuilder<> &Builder, ArrayRef<bool > IsByRef, bool IsGPU) {
3715+ Module *Module = ReductionFunc->getParent ();
// Create the (unnamed) block that will hold the whole function body and
// emit everything below into it.
3716+ BasicBlock *ReductionFuncBlock =
3717+ BasicBlock::Create (Module->getContext (), " " , ReductionFunc);
3718+ Builder.SetInsertPoint (ReductionFuncBlock);
3719+ Value *LHSArrayPtr = nullptr ;
3720+ Value *RHSArrayPtr = nullptr ;
3721+ if (IsGPU) {
3722+ // Need to alloca memory here and deal with the pointers before getting
3723+ // LHS/RHS pointers out
3724+ //
3725+ Argument *Arg0 = ReductionFunc->getArg (0 );
3726+ Argument *Arg1 = ReductionFunc->getArg (1 );
3727+ Type *Arg0Type = Arg0->getType ();
3728+ Type *Arg1Type = Arg1->getType ();
3729+
// One stack slot per argument, each sized for the argument's own type.
3730+ Value *LHSAlloca =
3731+ Builder.CreateAlloca (Arg0Type, nullptr , Arg0->getName () + " .addr" );
3732+ Value *RHSAlloca =
3733+ Builder.CreateAlloca (Arg1Type, nullptr , Arg1->getName () + " .addr" );
// The slot addresses are cast to the argument types before use.
// NOTE(review): presumably this moves the alloca pointer into the same
// address space as the incoming pointers - confirm against the GPU targets.
3734+ Value *LHSAddrCast =
3735+ Builder.CreatePointerBitCastOrAddrSpaceCast (LHSAlloca, Arg0Type);
3736+ Value *RHSAddrCast =
3737+ Builder.CreatePointerBitCastOrAddrSpaceCast (RHSAlloca, Arg1Type);
// Spill both arguments and immediately reload them; the reloaded values are
// what the loop below indexes.
3738+ Builder.CreateStore (Arg0, LHSAddrCast);
3739+ Builder.CreateStore (Arg1, RHSAddrCast);
3740+ LHSArrayPtr = Builder.CreateLoad (Arg0Type, LHSAddrCast);
3741+ RHSArrayPtr = Builder.CreateLoad (Arg1Type, RHSAddrCast);
3742+ } else {
// Non-GPU path: index the incoming arguments directly.
3743+ LHSArrayPtr = ReductionFunc->getArg (0 );
3744+ RHSArrayPtr = ReductionFunc->getArg (1 );
37073745 }
37083746
// Both arguments are addressed as an array of NumReductions pointers.
3747+ unsigned NumReductions = ReductionInfos.size ();
3748+ Type *RedArrayTy = ArrayType::get (Builder.getPtrTy (), NumReductions);
3749+
3750+ for (auto En : enumerate(ReductionInfos)) {
3751+ const OpenMPIRBuilder::ReductionInfo &RI = En.value ();
// LHS side: address of array slot En.index(), the pointer stored there, that
// pointer cast to the type of RI.Variable, and finally the element value.
3752+ Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3753+ RedArrayTy, LHSArrayPtr, 0 , En.index ());
3754+ Value *LHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), LHSI8PtrPtr);
3755+ Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast (
3756+ LHSI8Ptr, RI.Variable ->getType ());
3757+ Value *LHS = Builder.CreateLoad (RI.ElementType , LHSPtr);
// RHS side: same sequence, cast to the type of RI.PrivateVariable.
3758+ Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3759+ RedArrayTy, RHSArrayPtr, 0 , En.index ());
3760+ Value *RHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), RHSI8PtrPtr);
3761+ Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast (
3762+ RHSI8Ptr, RI.PrivateVariable ->getType ());
3763+ Value *RHS = Builder.CreateLoad (RI.ElementType , RHSPtr);
// Reduced is passed to the callback uninitialized and is read afterwards;
// it is expected to be set by ReductionGen (TODO confirm callback contract).
3764+ Value *Reduced;
3765+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3766+ RI.ReductionGen (Builder.saveIP (), LHS, RHS, Reduced);
// Propagate a callback failure to the caller unchanged.
3767+ if (!AfterIP)
3768+ return AfterIP.takeError ();
3769+
// Continue emitting wherever the callback left off.
3770+ Builder.restoreIP (*AfterIP);
3771+ // TODO: Consider flagging an error.
// If the callback returned an insert point without a block, stop here and
// report success; no further entries are processed and no `ret void` is
// emitted in that case.
3772+ if (!Builder.GetInsertBlock ())
3773+ return Error::success ();
3774+
3775+ // store is inside of the reduction region when using by-ref
3776+ if (!IsByRef[En.index ()])
3777+ Builder.CreateStore (Reduced, LHSPtr);
3778+ }
// All entries handled: terminate the function body.
3779+ Builder.CreateRetVoid ();
3780+ return Error::success ();
3781+ }
3782+
3783+ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions (
3784+ const LocationDescription &Loc, InsertPointTy AllocaIP,
3785+ ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool > IsByRef,
3786+ bool IsNoWait, bool IsTeamsReduction) {
3787+ assert (ReductionInfos.size () == IsByRef.size ());
3788+ if (Config.isGPU ())
3789+ return createReductionsGPU (Loc, AllocaIP, Builder.saveIP (), ReductionInfos,
3790+ IsNoWait, IsTeamsReduction);
3791+
3792+ checkReductionInfos (ReductionInfos, /* IsGPU*/ false );
3793+
37093794 if (!updateToLocation (Loc))
37103795 return InsertPointTy ();
37113796
3797+ if (ReductionInfos.size () == 0 )
3798+ return Builder.saveIP ();
3799+
37123800 BasicBlock *InsertBlock = Loc.IP .getBlock ();
37133801 BasicBlock *ContinuationBlock =
37143802 InsertBlock->splitBasicBlock (Loc.IP .getPoint (), " reduce.finalize" );
@@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
38323920 // Populate the outlined reduction function using the elementwise reduction
38333921 // function. Partial values are extracted from the type-erased array of
38343922 // pointers to private variables.
3835- BasicBlock *ReductionFuncBlock =
3836- BasicBlock::Create (Module->getContext (), " " , ReductionFunc);
3837- Builder.SetInsertPoint (ReductionFuncBlock);
3838- Value *LHSArrayPtr = ReductionFunc->getArg (0 );
3839- Value *RHSArrayPtr = ReductionFunc->getArg (1 );
3923+ Error Err = populateReductionFunction (ReductionFunc, ReductionInfos, Builder,
3924+ IsByRef, /* isGPU=*/ false );
3925+ if (Err)
3926+ return Err;
38403927
3841- for (auto En : enumerate(ReductionInfos)) {
3842- const ReductionInfo &RI = En.value ();
3843- Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3844- RedArrayTy, LHSArrayPtr, 0 , En.index ());
3845- Value *LHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), LHSI8PtrPtr);
3846- Value *LHSPtr = Builder.CreateBitCast (LHSI8Ptr, RI.Variable ->getType ());
3847- Value *LHS = Builder.CreateLoad (RI.ElementType , LHSPtr);
3848- Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3849- RedArrayTy, RHSArrayPtr, 0 , En.index ());
3850- Value *RHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), RHSI8PtrPtr);
3851- Value *RHSPtr =
3852- Builder.CreateBitCast (RHSI8Ptr, RI.PrivateVariable ->getType ());
3853- Value *RHS = Builder.CreateLoad (RI.ElementType , RHSPtr);
3854- Value *Reduced;
3855- InsertPointOrErrorTy AfterIP =
3856- RI.ReductionGen (Builder.saveIP (), LHS, RHS, Reduced);
3857- if (!AfterIP)
3858- return AfterIP.takeError ();
3859- Builder.restoreIP (*AfterIP);
3860- if (!Builder.GetInsertBlock ())
3861- return InsertPointTy ();
3862- // store is inside of the reduction region when using by-ref
3863- if (!IsByRef[En.index ()])
3864- Builder.CreateStore (Reduced, LHSPtr);
3865- }
3866- Builder.CreateRetVoid ();
3928+ if (!Builder.GetInsertBlock ())
3929+ return InsertPointTy ();
38673930
38683931 Builder.SetInsertPoint (ContinuationBlock);
38693932 return Builder.saveIP ();
@@ -6239,8 +6302,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
62396302 Constant *MaxThreads = ConstantInt::getSigned (Int32, MaxThreadsVal);
62406303 Constant *MinTeams = ConstantInt::getSigned (Int32, Attrs.MinTeams );
62416304 Constant *MaxTeams = ConstantInt::getSigned (Int32, Attrs.MaxTeams .front ());
6242- Constant *ReductionDataSize = ConstantInt::getSigned (Int32, 0 );
6243- Constant *ReductionBufferLength = ConstantInt::getSigned (Int32, 0 );
6305+ Constant *ReductionDataSize =
6306+ ConstantInt::getSigned (Int32, Attrs.ReductionDataSize );
6307+ Constant *ReductionBufferLength =
6308+ ConstantInt::getSigned (Int32, Attrs.ReductionBufferLength );
62446309
62456310 Function *Fn = getOrCreateRuntimeFunctionPtr (
62466311 omp::RuntimeFunction::OMPRTL___kmpc_target_init);
0 commit comments