Skip to content

Commit 20b2b94

Browse files
committed
Adding a few corrections
1 parent 3072fbc commit 20b2b94

File tree

5 files changed

+561
-107
lines changed

5 files changed

+561
-107
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2715,7 +2715,7 @@ class OpenMPIRBuilder {
27152715
/// \param ScanVars Scan Variables.
27162716
/// \param IsInclusive Whether it is an inclusive or exclusive scan.
27172717
///
2718-
/// \returns The insertion position *after* the masked.
2718+
/// \returns The insertion position *after* the scan.
27192719
InsertPointOrErrorTy createScan(const LocationDescription &Loc,
27202720
InsertPointTy AllocaIP,
27212721
ArrayRef<llvm::Value *> ScanVars,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 113 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4103,94 +4103,111 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
41034103
const LocationDescription &Loc, InsertPointTy &FinalizeIP,
41044104
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
41054105

4106-
llvm::Value *spanDiff = ScanInfo.Span;
4107-
41084106
if (!updateToLocation(Loc))
41094107
return Loc.IP;
4110-
auto curFn = Builder.GetInsertBlock()->getParent();
4111-
// for (int k = 0; k <= ceil(log2(n)); ++k)
4112-
llvm::BasicBlock *LoopBB =
4113-
BasicBlock::Create(curFn->getContext(), "omp.outer.log.scan.body");
4114-
llvm::BasicBlock *ExitBB =
4115-
BasicBlock::Create(curFn->getContext(), "omp.outer.log.scan.exit");
4116-
llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
4117-
Builder.GetInsertBlock()->getModule(),
4118-
(llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
4119-
llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
4120-
ConstantInt *One = ConstantInt::get(Builder.getInt32Ty(), 1);
4121-
llvm::Value *span = ScanInfo.Span; // Builder.CreateAdd(spanDiff, One);
4122-
llvm::Value *Arg = Builder.CreateUIToFP(span, Builder.getDoubleTy());
4123-
llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
4124-
F = llvm::Intrinsic::getOrInsertDeclaration(
4125-
Builder.GetInsertBlock()->getModule(),
4126-
(llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
4127-
LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
4128-
LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
4129-
llvm::Value *NMin1 =
4130-
Builder.CreateNUWSub(span, llvm::ConstantInt::get(span->getType(), 1));
4131-
Builder.SetInsertPoint(InputBB);
4132-
Builder.CreateBr(LoopBB);
4133-
emitBlock(LoopBB, Builder.GetInsertBlock()->getParent());
4134-
Builder.SetInsertPoint(LoopBB);
4135-
4136-
PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
4137-
//// size pow2k = 1;
4138-
PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
4139-
Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
4108+
auto BodyGenCB = [&](InsertPointTy AllocaIP,
4109+
InsertPointTy CodeGenIP) -> Error {
4110+
Builder.restoreIP(CodeGenIP);
4111+
auto CurFn = Builder.GetInsertBlock()->getParent();
4112+
// for (int k = 0; k != ceil(log2(n)); ++k)
4113+
llvm::BasicBlock *LoopBB =
4114+
BasicBlock::Create(CurFn->getContext(), "omp.outer.log.scan.body");
4115+
llvm::BasicBlock *ExitBB =
4116+
splitBB(Builder, false, "omp.outer.log.scan.exit");
4117+
llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
4118+
Builder.GetInsertBlock()->getModule(),
4119+
(llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
4120+
llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
4121+
llvm::Value *Arg =
4122+
Builder.CreateUIToFP(ScanInfo.Span, Builder.getDoubleTy());
4123+
llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
4124+
F = llvm::Intrinsic::getOrInsertDeclaration(
4125+
Builder.GetInsertBlock()->getModule(),
4126+
(llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
4127+
LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
4128+
LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
4129+
llvm::Value *NMin1 = Builder.CreateNUWSub(
4130+
ScanInfo.Span, llvm::ConstantInt::get(ScanInfo.Span->getType(), 1));
4131+
Builder.SetInsertPoint(InputBB);
4132+
Builder.CreateBr(LoopBB);
4133+
emitBlock(LoopBB, Builder.GetInsertBlock()->getParent());
4134+
Builder.SetInsertPoint(LoopBB);
4135+
4136+
PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
4137+
//// size pow2k = 1;
4138+
PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
4139+
Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
4140+
InputBB);
4141+
Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1),
41404142
InputBB);
4141-
Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1), InputBB);
4142-
//// for (size i = n - 1; i >= 2 ^ k; --i)
4143-
//// tmp[i] op= tmp[i-pow2k];
4144-
llvm::BasicBlock *InnerLoopBB =
4145-
BasicBlock::Create(curFn->getContext(), "omp.inner.log.scan.body");
4146-
llvm::BasicBlock *InnerExitBB =
4147-
BasicBlock::Create(curFn->getContext(), "omp.inner.log.scan.exit");
4148-
llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
4149-
Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
4150-
emitBlock(InnerLoopBB, Builder.GetInsertBlock()->getParent());
4151-
Builder.SetInsertPoint(InnerLoopBB);
4152-
auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
4153-
IVal->addIncoming(NMin1, LoopBB);
4154-
unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
4155-
for (ReductionInfo RedInfo : ReductionInfos) {
4156-
Value *ReductionVal = RedInfo.PrivateVariable;
4157-
Value *Buff = ScanInfo.ReductionVarToScanBuffs[ReductionVal];
4158-
Type *DestTy = RedInfo.ElementType;
4159-
Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
4160-
Value *LHSPtr = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
4161-
Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K);
4162-
Value *RHSPtr =
4163-
Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset");
4164-
Value *LHS = Builder.CreateLoad(DestTy, LHSPtr);
4165-
Value *RHS = Builder.CreateLoad(DestTy, RHSPtr);
4166-
Value *LHSAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4167-
LHSPtr, RHS->getType()->getPointerTo(defaultAS));
4168-
llvm::Value *Result;
4169-
InsertPointOrErrorTy AfterIP =
4170-
RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
4171-
if (!AfterIP)
4172-
return AfterIP.takeError();
4173-
Builder.CreateStore(Result, LHSAddr);
4174-
}
4175-
llvm::Value *NextIVal = Builder.CreateNUWSub(
4176-
IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1));
4177-
IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
4178-
CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
4179-
Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
4180-
emitBlock(InnerExitBB, Builder.GetInsertBlock()->getParent());
4181-
llvm::Value *Next = Builder.CreateNUWAdd(
4182-
Counter, llvm::ConstantInt::get(Counter->getType(), 1));
4183-
Counter->addIncoming(Next, Builder.GetInsertBlock());
4184-
// pow2k <<= 1;
4185-
llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
4186-
Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
4187-
llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal);
4188-
Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
4189-
emitBlock(ExitBB, Builder.GetInsertBlock()->getParent());
4190-
Builder.SetInsertPoint(ExitBB);
4143+
//// for (size i = n - 1; i >= pow2k; --i)
4144+
//// tmp[i] op= tmp[i-pow2k];
4145+
llvm::BasicBlock *InnerLoopBB =
4146+
BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.body");
4147+
llvm::BasicBlock *InnerExitBB =
4148+
BasicBlock::Create(CurFn->getContext(), "omp.inner.log.scan.exit");
4149+
llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
4150+
Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
4151+
emitBlock(InnerLoopBB, Builder.GetInsertBlock()->getParent());
4152+
Builder.SetInsertPoint(InnerLoopBB);
4153+
auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
4154+
IVal->addIncoming(NMin1, LoopBB);
4155+
unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
4156+
for (ReductionInfo RedInfo : ReductionInfos) {
4157+
Value *ReductionVal = RedInfo.PrivateVariable;
4158+
Value *Buff = ScanInfo.ReductionVarToScanBuffs[ReductionVal];
4159+
Type *DestTy = RedInfo.ElementType;
4160+
Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
4161+
Value *LHSPtr =
4162+
Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
4163+
Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K);
4164+
Value *RHSPtr =
4165+
Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset");
4166+
Value *LHS = Builder.CreateLoad(DestTy, LHSPtr);
4167+
Value *RHS = Builder.CreateLoad(DestTy, RHSPtr);
4168+
Value *LHSAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
4169+
LHSPtr, RHS->getType()->getPointerTo(defaultAS));
4170+
llvm::Value *Result;
4171+
InsertPointOrErrorTy AfterIP =
4172+
RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
4173+
if (!AfterIP)
4174+
return AfterIP.takeError();
4175+
Builder.CreateStore(Result, LHSAddr);
4176+
}
4177+
llvm::Value *NextIVal = Builder.CreateNUWSub(
4178+
IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1));
4179+
IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
4180+
CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
4181+
Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
4182+
emitBlock(InnerExitBB, Builder.GetInsertBlock()->getParent());
4183+
llvm::Value *Next = Builder.CreateNUWAdd(
4184+
Counter, llvm::ConstantInt::get(Counter->getType(), 1));
4185+
Counter->addIncoming(Next, Builder.GetInsertBlock());
4186+
// pow2k <<= 1;
4187+
llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
4188+
Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
4189+
llvm::Value *Cmp = Builder.CreateICmpNE(Next, LogVal);
4190+
Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
4191+
Builder.SetInsertPoint(ExitBB->getFirstInsertionPt());
4192+
return Error::success();
4193+
};
4194+
4195+
// TODO: Perform finalization actions for variables. This has to be
4196+
// called for variables which have destructors/finalizers.
4197+
auto FiniCB = [&](InsertPointTy CodeGenIP) { return llvm::Error::success(); };
4198+
4199+
llvm::Value *FilterVal = Builder.getInt32(0);
41914200
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
4192-
createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
4201+
createMasked(Builder.saveIP(), BodyGenCB, FiniCB, FilterVal);
4202+
4203+
if (!AfterIP)
4204+
return AfterIP.takeError();
4205+
Builder.restoreIP(*AfterIP);
4206+
AfterIP = createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
41934207

4208+
if (!AfterIP)
4209+
return AfterIP.takeError();
4210+
Builder.restoreIP(*AfterIP);
41944211
Builder.restoreIP(FinalizeIP);
41954212
emitScanBasedDirectiveFinalsIR(ReductionInfos);
41964213
FinalizeIP = Builder.saveIP();
@@ -4204,7 +4221,6 @@ Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
42044221

42054222
{
42064223
// Emit loop with input phase:
4207-
// #pragma omp ...
42084224
// for (i: 0..<num_iters>) {
42094225
// <input phase>;
42104226
// buffer[i] = red;
@@ -4215,6 +4231,11 @@ Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
42154231
return Result;
42164232
}
42174233
{
4234+
// Emit loop with scan phase:
4235+
// for (i: 0..<num_iters>) {
4236+
// red = buffer[i];
4237+
// <scan phase>;
4238+
// }
42184239
ScanInfo.OMPFirstScanLoop = false;
42194240
auto Result = ScanLoopGen(Builder.saveIP());
42204241
if (Result)
@@ -4224,17 +4245,17 @@ Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
42244245
}
42254246

42264247
void OpenMPIRBuilder::createScanBBs() {
4227-
auto fun = Builder.GetInsertBlock()->getParent();
4248+
Function *Fun = Builder.GetInsertBlock()->getParent();
42284249
ScanInfo.OMPScanExitBlock =
4229-
BasicBlock::Create(fun->getContext(), "omp.exit.inscan.bb");
4250+
BasicBlock::Create(Fun->getContext(), "omp.exit.inscan.bb");
42304251
ScanInfo.OMPScanDispatch =
4231-
BasicBlock::Create(fun->getContext(), "omp.inscan.dispatch");
4252+
BasicBlock::Create(Fun->getContext(), "omp.inscan.dispatch");
42324253
ScanInfo.OMPAfterScanBlock =
4233-
BasicBlock::Create(fun->getContext(), "omp.after.scan.bb");
4254+
BasicBlock::Create(Fun->getContext(), "omp.after.scan.bb");
42344255
ScanInfo.OMPBeforeScanBlock =
4235-
BasicBlock::Create(fun->getContext(), "omp.before.scan.bb");
4256+
BasicBlock::Create(Fun->getContext(), "omp.before.scan.bb");
42364257
ScanInfo.OMPScanLoopExit =
4237-
BasicBlock::Create(fun->getContext(), "omp.scan.loop.exit");
4258+
BasicBlock::Create(Fun->getContext(), "omp.scan.loop.exit");
42384259
}
42394260

42404261
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
@@ -5454,7 +5475,7 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
54545475
// TODO: It would be sufficient to only sink them into body of the
54555476
// corresponding tile loop.
54565477
SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
5457-
for (size_t i = 0; i < NumLoops - 1; ++i) {
5478+
for (int i = 0; i < NumLoops - 1; ++i) {
54585479
CanonicalLoopInfo *Surrounding = Loops[i];
54595480
CanonicalLoopInfo *Nested = Loops[i + 1];
54605481

@@ -5467,7 +5488,7 @@ OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
54675488
Builder.SetCurrentDebugLocation(DL);
54685489
Builder.restoreIP(OutermostLoop->getPreheaderIP());
54695490
SmallVector<Value *, 4> FloorCount, FloorRems;
5470-
for (size_t i = 0; i < NumLoops; ++i) {
5491+
for (int i = 0; i < NumLoops; ++i) {
54715492
Value *TileSize = TileSizes[i];
54725493
Value *OrigTripCount = OrigTripCounts[i];
54735494
Type *IVType = OrigTripCount->getType();

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5376,29 +5376,32 @@ TEST_F(OpenMPIRBuilderTest, ScanReduction) {
53765376
false, Builder.saveIP(), "scan"));
53775377
Loops = loopsVec;
53785378
EXPECT_EQ(Loops.size(), 2U);
5379-
auto inputLoop = Loops.front();
5380-
auto scanLoop = Loops.back();
5381-
Builder.restoreIP(scanLoop->getAfterIP());
5382-
inputLoop->assertOK();
5383-
scanLoop->assertOK();
5384-
5385-
//// Verify control flow structure (in addition to Loop->assertOK()).
5386-
EXPECT_EQ(inputLoop->getPreheader()->getSinglePredecessor(),
5379+
CanonicalLoopInfo *InputLoop = Loops.front();
5380+
CanonicalLoopInfo *ScanLoop = Loops.back();
5381+
Builder.restoreIP(ScanLoop->getAfterIP());
5382+
InputLoop->assertOK();
5383+
ScanLoop->assertOK();
5384+
5385+
EXPECT_EQ(InputLoop->getPreheader()->getSinglePredecessor(),
53875386
&F->getEntryBlock());
5388-
EXPECT_EQ(scanLoop->getAfter(), Builder.GetInsertBlock());
5387+
EXPECT_EQ(ScanLoop->getAfter(), Builder.GetInsertBlock());
53895388
EXPECT_EQ(NumBodiesGenerated, 2U);
53905389
SmallVector<OpenMPIRBuilder::ReductionInfo> reductionInfos = {
53915390
{Builder.getFloatTy(), origVar, scanVar,
53925391
/*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction,
53935392
/*ReductionGenClang=*/nullptr, sumAtomicReduction}};
5394-
auto FinalizeIP = scanLoop->getAfterIP();
5395-
OpenMPIRBuilder::LocationDescription RedLoc({inputLoop->getAfterIP(), DL});
5393+
auto FinalizeIP = ScanLoop->getAfterIP();
5394+
OpenMPIRBuilder::LocationDescription RedLoc({InputLoop->getAfterIP(), DL});
53965395
llvm::BasicBlock *Cont = splitBB(Builder, false, "omp.scan.loop.cont");
53975396
ASSERT_EXPECTED_INIT(
53985397
InsertPointTy, retIp,
53995398
OMPBuilder.emitScanReduction(RedLoc, FinalizeIP, reductionInfos));
54005399
Builder.restoreIP(retIp);
54015400
Builder.CreateBr(Cont);
5401+
SmallVector<CallInst *> MaskedCalls;
5402+
findCalls(F, omp::RuntimeFunction::OMPRTL___kmpc_masked, OMPBuilder,
5403+
MaskedCalls);
5404+
ASSERT_EQ(MaskedCalls.size(), 1u);
54025405
}
54035406

54045407
TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747

4848
using namespace mlir;
4949

50-
llvm::SmallDenseMap<llvm::Value *, llvm::Type *> ReductionVarToType;
50+
llvm::SmallDenseMap<llvm::Value *, llvm::Type *> ReductionVarToType;
51+
llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP; // TODO: Change this alloca IP to point to the original variable's alloca IP. The ReductionDecl needs to be linked to the scan variable.
5152
namespace {
5253
static llvm::omp::ScheduleKind
5354
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
@@ -2578,6 +2579,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
25782579

25792580
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
25802581
findAllocaInsertPoint(builder, moduleTranslation);
2582+
parallelAllocaIP = allocaIP;
25812583
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
25822584

25832585
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
@@ -2619,9 +2621,13 @@ convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
26192621
llvm::Value *llvmVal = moduleTranslation.lookupValue(val);
26202622
llvmScanVars.push_back(llvmVal);
26212623
llvmScanVarsType.push_back(ReductionVarToType[llvmVal]);
2624+
val.getDefiningOp();
26222625
}
2623-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2624-
findAllocaInsertPoint(builder, moduleTranslation);
2626+
auto parallelOp = scanOp->getParentOfType<omp::ParallelOp>();
2627+
if (!parallelOp) {
2628+
return failure();
2629+
}
2630+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP = parallelAllocaIP;
26252631
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
26262632
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
26272633
moduleTranslation.getOpenMPBuilder()->createScan(

0 commit comments

Comments
 (0)