@@ -4103,94 +4103,111 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
41034103 const LocationDescription &Loc, InsertPointTy &FinalizeIP,
41044104 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
41054105
4106- llvm::Value *spanDiff = ScanInfo.Span ;
4107-
41084106 if (!updateToLocation (Loc))
41094107 return Loc.IP ;
4110- auto curFn = Builder.GetInsertBlock ()->getParent ();
4111- // for (int k = 0; k <= ceil(log2(n)); ++k)
4112- llvm::BasicBlock *LoopBB =
4113- BasicBlock::Create (curFn->getContext (), " omp.outer.log.scan.body" );
4114- llvm::BasicBlock *ExitBB =
4115- BasicBlock::Create (curFn->getContext (), " omp.outer.log.scan.exit" );
4116- llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration (
4117- Builder.GetInsertBlock ()->getModule (),
4118- (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy ());
4119- llvm::BasicBlock *InputBB = Builder.GetInsertBlock ();
4120- ConstantInt *One = ConstantInt::get (Builder.getInt32Ty (), 1 );
4121- llvm::Value *span = ScanInfo.Span ; // Builder.CreateAdd(spanDiff, One);
4122- llvm::Value *Arg = Builder.CreateUIToFP (span, Builder.getDoubleTy ());
4123- llvm::Value *LogVal = emitNoUnwindRuntimeCall (F, Arg, " " );
4124- F = llvm::Intrinsic::getOrInsertDeclaration (
4125- Builder.GetInsertBlock ()->getModule (),
4126- (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy ());
4127- LogVal = emitNoUnwindRuntimeCall (F, LogVal, " " );
4128- LogVal = Builder.CreateFPToUI (LogVal, Builder.getInt32Ty ());
4129- llvm::Value *NMin1 =
4130- Builder.CreateNUWSub (span, llvm::ConstantInt::get (span->getType (), 1 ));
4131- Builder.SetInsertPoint (InputBB);
4132- Builder.CreateBr (LoopBB);
4133- emitBlock (LoopBB, Builder.GetInsertBlock ()->getParent ());
4134- Builder.SetInsertPoint (LoopBB);
4135-
4136- PHINode *Counter = Builder.CreatePHI (Builder.getInt32Ty (), 2 );
4137- // // size pow2k = 1;
4138- PHINode *Pow2K = Builder.CreatePHI (Builder.getInt32Ty (), 2 );
4139- Counter->addIncoming (llvm::ConstantInt::get (Builder.getInt32Ty (), 0 ),
4108+ auto BodyGenCB = [&](InsertPointTy AllocaIP,
4109+ InsertPointTy CodeGenIP) -> Error {
4110+ Builder.restoreIP (CodeGenIP);
4111+ auto CurFn = Builder.GetInsertBlock ()->getParent ();
4112+ // for (int k = 0; k <= ceil(log2(n)); ++k)
4113+ llvm::BasicBlock *LoopBB =
4114+ BasicBlock::Create (CurFn->getContext (), " omp.outer.log.scan.body" );
4115+ llvm::BasicBlock *ExitBB =
4116+ splitBB (Builder, false , " omp.outer.log.scan.exit" );
4117+ llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration (
4118+ Builder.GetInsertBlock ()->getModule (),
4119+ (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy ());
4120+ llvm::BasicBlock *InputBB = Builder.GetInsertBlock ();
4121+ llvm::Value *Arg =
4122+ Builder.CreateUIToFP (ScanInfo.Span , Builder.getDoubleTy ());
4123+ llvm::Value *LogVal = emitNoUnwindRuntimeCall (F, Arg, " " );
4124+ F = llvm::Intrinsic::getOrInsertDeclaration (
4125+ Builder.GetInsertBlock ()->getModule (),
4126+ (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy ());
4127+ LogVal = emitNoUnwindRuntimeCall (F, LogVal, " " );
4128+ LogVal = Builder.CreateFPToUI (LogVal, Builder.getInt32Ty ());
4129+ llvm::Value *NMin1 = Builder.CreateNUWSub (
4130+ ScanInfo.Span , llvm::ConstantInt::get (ScanInfo.Span ->getType (), 1 ));
4131+ Builder.SetInsertPoint (InputBB);
4132+ Builder.CreateBr (LoopBB);
4133+ emitBlock (LoopBB, Builder.GetInsertBlock ()->getParent ());
4134+ Builder.SetInsertPoint (LoopBB);
4135+
4136+ PHINode *Counter = Builder.CreatePHI (Builder.getInt32Ty (), 2 );
4137+ // // size pow2k = 1;
4138+ PHINode *Pow2K = Builder.CreatePHI (Builder.getInt32Ty (), 2 );
4139+ Counter->addIncoming (llvm::ConstantInt::get (Builder.getInt32Ty (), 0 ),
4140+ InputBB);
4141+ Pow2K->addIncoming (llvm::ConstantInt::get (Builder.getInt32Ty (), 1 ),
41404142 InputBB);
4141- Pow2K->addIncoming (llvm::ConstantInt::get (Builder.getInt32Ty (), 1 ), InputBB);
4142- // // for (size i = n - 1; i >= 2 ^ k; --i)
4143- // // tmp[i] op= tmp[i-pow2k];
4144- llvm::BasicBlock *InnerLoopBB =
4145- BasicBlock::Create (curFn->getContext (), " omp.inner.log.scan.body" );
4146- llvm::BasicBlock *InnerExitBB =
4147- BasicBlock::Create (curFn->getContext (), " omp.inner.log.scan.exit" );
4148- llvm::Value *CmpI = Builder.CreateICmpUGE (NMin1, Pow2K);
4149- Builder.CreateCondBr (CmpI, InnerLoopBB, InnerExitBB);
4150- emitBlock (InnerLoopBB, Builder.GetInsertBlock ()->getParent ());
4151- Builder.SetInsertPoint (InnerLoopBB);
4152- auto *IVal = Builder.CreatePHI (Builder.getInt32Ty (), 2 );
4153- IVal->addIncoming (NMin1, LoopBB);
4154- unsigned int defaultAS = M.getDataLayout ().getProgramAddressSpace ();
4155- for (ReductionInfo RedInfo : ReductionInfos) {
4156- Value *ReductionVal = RedInfo.PrivateVariable ;
4157- Value *Buff = ScanInfo.ReductionVarToScanBuffs [ReductionVal];
4158- Type *DestTy = RedInfo.ElementType ;
4159- Value *IV = Builder.CreateAdd (IVal, Builder.getInt32 (1 ));
4160- Value *LHSPtr = Builder.CreateInBoundsGEP (DestTy, Buff, IV, " arrayOffset" );
4161- Value *OffsetIval = Builder.CreateNUWSub (IV, Pow2K);
4162- Value *RHSPtr =
4163- Builder.CreateInBoundsGEP (DestTy, Buff, OffsetIval, " arrayOffset" );
4164- Value *LHS = Builder.CreateLoad (DestTy, LHSPtr);
4165- Value *RHS = Builder.CreateLoad (DestTy, RHSPtr);
4166- Value *LHSAddr = Builder.CreatePointerBitCastOrAddrSpaceCast (
4167- LHSPtr, RHS->getType ()->getPointerTo (defaultAS));
4168- llvm::Value *Result;
4169- InsertPointOrErrorTy AfterIP =
4170- RedInfo.ReductionGen (Builder.saveIP (), LHS, RHS, Result);
4171- if (!AfterIP)
4172- return AfterIP.takeError ();
4173- Builder.CreateStore (Result, LHSAddr);
4174- }
4175- llvm::Value *NextIVal = Builder.CreateNUWSub (
4176- IVal, llvm::ConstantInt::get (Builder.getInt32Ty (), 1 ));
4177- IVal->addIncoming (NextIVal, Builder.GetInsertBlock ());
4178- CmpI = Builder.CreateICmpUGE (NextIVal, Pow2K);
4179- Builder.CreateCondBr (CmpI, InnerLoopBB, InnerExitBB);
4180- emitBlock (InnerExitBB, Builder.GetInsertBlock ()->getParent ());
4181- llvm::Value *Next = Builder.CreateNUWAdd (
4182- Counter, llvm::ConstantInt::get (Counter->getType (), 1 ));
4183- Counter->addIncoming (Next, Builder.GetInsertBlock ());
4184- // pow2k <<= 1;
4185- llvm::Value *NextPow2K = Builder.CreateShl (Pow2K, 1 , " " , /* HasNUW=*/ true );
4186- Pow2K->addIncoming (NextPow2K, Builder.GetInsertBlock ());
4187- llvm::Value *Cmp = Builder.CreateICmpNE (Next, LogVal);
4188- Builder.CreateCondBr (Cmp, LoopBB, ExitBB);
4189- emitBlock (ExitBB, Builder.GetInsertBlock ()->getParent ());
4190- Builder.SetInsertPoint (ExitBB);
4143+ // // for (size i = n - 1; i >= 2 ^ k; --i)
4144+ // // tmp[i] op= tmp[i-pow2k];
4145+ llvm::BasicBlock *InnerLoopBB =
4146+ BasicBlock::Create (CurFn->getContext (), " omp.inner.log.scan.body" );
4147+ llvm::BasicBlock *InnerExitBB =
4148+ BasicBlock::Create (CurFn->getContext (), " omp.inner.log.scan.exit" );
4149+ llvm::Value *CmpI = Builder.CreateICmpUGE (NMin1, Pow2K);
4150+ Builder.CreateCondBr (CmpI, InnerLoopBB, InnerExitBB);
4151+ emitBlock (InnerLoopBB, Builder.GetInsertBlock ()->getParent ());
4152+ Builder.SetInsertPoint (InnerLoopBB);
4153+ auto *IVal = Builder.CreatePHI (Builder.getInt32Ty (), 2 );
4154+ IVal->addIncoming (NMin1, LoopBB);
4155+ unsigned int defaultAS = M.getDataLayout ().getProgramAddressSpace ();
4156+ for (ReductionInfo RedInfo : ReductionInfos) {
4157+ Value *ReductionVal = RedInfo.PrivateVariable ;
4158+ Value *Buff = ScanInfo.ReductionVarToScanBuffs [ReductionVal];
4159+ Type *DestTy = RedInfo.ElementType ;
4160+ Value *IV = Builder.CreateAdd (IVal, Builder.getInt32 (1 ));
4161+ Value *LHSPtr =
4162+ Builder.CreateInBoundsGEP (DestTy, Buff, IV, " arrayOffset" );
4163+ Value *OffsetIval = Builder.CreateNUWSub (IV, Pow2K);
4164+ Value *RHSPtr =
4165+ Builder.CreateInBoundsGEP (DestTy, Buff, OffsetIval, " arrayOffset" );
4166+ Value *LHS = Builder.CreateLoad (DestTy, LHSPtr);
4167+ Value *RHS = Builder.CreateLoad (DestTy, RHSPtr);
4168+ Value *LHSAddr = Builder.CreatePointerBitCastOrAddrSpaceCast (
4169+ LHSPtr, RHS->getType ()->getPointerTo (defaultAS));
4170+ llvm::Value *Result;
4171+ InsertPointOrErrorTy AfterIP =
4172+ RedInfo.ReductionGen (Builder.saveIP (), LHS, RHS, Result);
4173+ if (!AfterIP)
4174+ return AfterIP.takeError ();
4175+ Builder.CreateStore (Result, LHSAddr);
4176+ }
4177+ llvm::Value *NextIVal = Builder.CreateNUWSub (
4178+ IVal, llvm::ConstantInt::get (Builder.getInt32Ty (), 1 ));
4179+ IVal->addIncoming (NextIVal, Builder.GetInsertBlock ());
4180+ CmpI = Builder.CreateICmpUGE (NextIVal, Pow2K);
4181+ Builder.CreateCondBr (CmpI, InnerLoopBB, InnerExitBB);
4182+ emitBlock (InnerExitBB, Builder.GetInsertBlock ()->getParent ());
4183+ llvm::Value *Next = Builder.CreateNUWAdd (
4184+ Counter, llvm::ConstantInt::get (Counter->getType (), 1 ));
4185+ Counter->addIncoming (Next, Builder.GetInsertBlock ());
4186+ // pow2k <<= 1;
4187+ llvm::Value *NextPow2K = Builder.CreateShl (Pow2K, 1 , " " , /* HasNUW=*/ true );
4188+ Pow2K->addIncoming (NextPow2K, Builder.GetInsertBlock ());
4189+ llvm::Value *Cmp = Builder.CreateICmpNE (Next, LogVal);
4190+ Builder.CreateCondBr (Cmp, LoopBB, ExitBB);
4191+ Builder.SetInsertPoint (ExitBB->getFirstInsertionPt ());
4192+ return Error::success ();
4193+ };
4194+
4195+ // TODO: Perform finalization actions for variables. This has to be
4196+ // called for variables which have destructors/finalizers.
4197+ auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success (); };
4198+
4199+ llvm::Value *FilterVal = Builder.getInt32 (0 );
41914200 llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4192- createBarrier (Builder.saveIP (), llvm::omp::OMPD_barrier);
4201+ createMasked (Builder.saveIP (), BodyGenCB, FiniCB, FilterVal);
4202+
4203+ if (!AfterIP)
4204+ return AfterIP.takeError ();
4205+ Builder.restoreIP (*AfterIP);
4206+ AfterIP = createBarrier (Builder.saveIP (), llvm::omp::OMPD_barrier);
41934207
4208+ if (!AfterIP)
4209+ return AfterIP.takeError ();
4210+ Builder.restoreIP (*AfterIP);
41944211 Builder.restoreIP (FinalizeIP);
41954212 emitScanBasedDirectiveFinalsIR (ReductionInfos);
41964213 FinalizeIP = Builder.saveIP ();
@@ -4204,7 +4221,6 @@ Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
42044221
42054222 {
42064223 // Emit loop with input phase:
4207- // #pragma omp ...
42084224 // for (i: 0..<num_iters>) {
42094225 // <input phase>;
42104226 // buffer[i] = red;
@@ -4215,6 +4231,11 @@ Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
42154231 return Result;
42164232 }
42174233 {
4234+ // Emit loop with scan phase:
4235+ // for (i: 0..<num_iters>) {
4236+ // red = buffer[i];
4237+ // <scan phase>;
4238+ // }
42184239 ScanInfo.OMPFirstScanLoop = false ;
42194240 auto Result = ScanLoopGen (Builder.saveIP ());
42204241 if (Result)
@@ -4224,17 +4245,17 @@ Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
42244245}
42254246
42264247void OpenMPIRBuilder::createScanBBs () {
4227- auto fun = Builder.GetInsertBlock ()->getParent ();
4248+ Function *Fun = Builder.GetInsertBlock ()->getParent ();
42284249 ScanInfo.OMPScanExitBlock =
4229- BasicBlock::Create (fun ->getContext (), " omp.exit.inscan.bb" );
4250+ BasicBlock::Create (Fun ->getContext (), " omp.exit.inscan.bb" );
42304251 ScanInfo.OMPScanDispatch =
4231- BasicBlock::Create (fun ->getContext (), " omp.inscan.dispatch" );
4252+ BasicBlock::Create (Fun ->getContext (), " omp.inscan.dispatch" );
42324253 ScanInfo.OMPAfterScanBlock =
4233- BasicBlock::Create (fun ->getContext (), " omp.after.scan.bb" );
4254+ BasicBlock::Create (Fun ->getContext (), " omp.after.scan.bb" );
42344255 ScanInfo.OMPBeforeScanBlock =
4235- BasicBlock::Create (fun ->getContext (), " omp.before.scan.bb" );
4256+ BasicBlock::Create (Fun ->getContext (), " omp.before.scan.bb" );
42364257 ScanInfo.OMPScanLoopExit =
4237- BasicBlock::Create (fun ->getContext (), " omp.scan.loop.exit" );
4258+ BasicBlock::Create (Fun ->getContext (), " omp.scan.loop.exit" );
42384259}
42394260
42404261CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton (
@@ -5454,7 +5475,7 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
54545475 // TODO: It would be sufficient to only sink them into body of the
54555476 // corresponding tile loop.
54565477 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4 > InbetweenCode;
5457- for (size_t i = 0 ; i < NumLoops - 1 ; ++i) {
5478+ for (int i = 0 ; i < NumLoops - 1 ; ++i) {
54585479 CanonicalLoopInfo *Surrounding = Loops[i];
54595480 CanonicalLoopInfo *Nested = Loops[i + 1 ];
54605481
@@ -5467,7 +5488,7 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
54675488 Builder.SetCurrentDebugLocation (DL);
54685489 Builder.restoreIP (OutermostLoop->getPreheaderIP ());
54695490 SmallVector<Value *, 4 > FloorCount, FloorRems;
5470- for (size_t i = 0 ; i < NumLoops; ++i) {
5491+ for (int i = 0 ; i < NumLoops; ++i) {
54715492 Value *TileSize = TileSizes[i];
54725493 Value *OrigTripCount = OrigTripCounts[i];
54735494 Type *IVType = OrigTripCount->getType ();
0 commit comments