Skip to content

Commit 097a869

Browse files
committed
Make convertOmpTarget use allocatePrivateVars.
Allows more code reuse by generalizing `allocatePrivateVars` to serve `target` op conversion.
1 parent f350b9c commit 097a869

File tree

6 files changed

+106
-96
lines changed

6 files changed

+106
-96
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6807,8 +6807,11 @@ static Expected<Function *> createOutlinedFunction(
68076807
OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
68086808

68096809
// Insert target deinit call in the device compilation pass.
6810-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6811-
CBFunc(Builder.saveIP(), Builder.saveIP());
6810+
BasicBlock *OutlinedBodyBB =
6811+
splitBB(Builder, /*CreateBranch=*/true, "outlined.body");
6812+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
6813+
Builder.saveIP(),
6814+
OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
68126815
if (!AfterIP)
68136816
return AfterIP.takeError();
68146817
Builder.restoreIP(*AfterIP);

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6358,7 +6358,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
63586358
auto *Load2 = Load1->getNextNode();
63596359
EXPECT_TRUE(isa<LoadInst>(Load2));
63606360

6361-
auto *Value1 = Load2->getNextNode();
6361+
auto *OutlinedBlockBr = Load2->getNextNode();
6362+
EXPECT_TRUE(isa<BranchInst>(OutlinedBlockBr));
6363+
6364+
auto *OutlinedBlock = OutlinedBlockBr->getSuccessor(0);
6365+
EXPECT_EQ(OutlinedBlock->getName(), "outlined.body");
6366+
6367+
auto *Value1 = OutlinedBlock->getFirstNonPHI();
63626368
EXPECT_EQ(Value1, Value);
63636369
EXPECT_EQ(Value1->getNextNode(), TargetStore);
63646370
auto *Deinit = TargetStore->getNextNode();
@@ -6510,7 +6516,14 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
65106516
EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
65116517
auto *Load1 = UserCodeBlock->getFirstNonPHI();
65126518
EXPECT_TRUE(isa<LoadInst>(Load1));
6513-
auto *Load2 = Load1->getNextNode();
6519+
6520+
auto *OutlinedBlockBr = Load1->getNextNode();
6521+
EXPECT_TRUE(isa<BranchInst>(OutlinedBlockBr));
6522+
6523+
auto *OutlinedBlock = OutlinedBlockBr->getSuccessor(0);
6524+
EXPECT_EQ(OutlinedBlock->getName(), "outlined.body");
6525+
6526+
auto *Load2 = OutlinedBlock->getFirstNonPHI();
65146527
EXPECT_TRUE(isa<LoadInst>(Load2));
65156528
EXPECT_EQ(Load2, Value);
65166529
EXPECT_EQ(Load2->getNextNode(), TargetStore);

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 76 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,6 +1286,44 @@ static LogicalResult allocAndInitializeReductionVars(
12861286
isByRef, deferredStores);
12871287
}
12881288

1289+
/// Return the llvm::Value * corresponding to the `privateVar` that
1290+
/// is being privatized. It isn't always as simple as looking up
1291+
/// moduleTranslation with privateVar. For instance, in case of
1292+
/// an allocatable, the descriptor for the allocatable is privatized.
1293+
/// This descriptor is mapped using an MapInfoOp. So, this function
1294+
/// will return a pointer to the llvm::Value corresponding to the
1295+
/// block argument for the mapped descriptor.
1296+
static llvm::Value *
1297+
findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1298+
LLVM::ModuleTranslation &moduleTranslation,
1299+
omp::TargetOp targetOp = nullptr,
1300+
llvm::DenseMap<Value, int> *mappedPrivateVars = nullptr) {
1301+
if (mappedPrivateVars != nullptr && mappedPrivateVars->contains(privateVar)) {
1302+
int blockArgIndex = (*mappedPrivateVars)[privateVar];
1303+
Value blockArg = targetOp.getRegion().getArgument(blockArgIndex);
1304+
Type privVarType = privateVar.getType();
1305+
Type blockArgType = blockArg.getType();
1306+
assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1307+
"A block argument corresponding to a mapped var should have "
1308+
"!llvm.ptr type");
1309+
1310+
if (privVarType == blockArg.getType())
1311+
return moduleTranslation.lookupValue(blockArg);
1312+
1313+
if (!isa<LLVM::LLVMPointerType>(privVarType)) {
1314+
// This typically happens when the privatized type is lowered from
1315+
// boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1316+
// struct/pair is passed by value. But, mapped values are passed only as
1317+
// pointers, so before we privatize, we must load the pointer.
1318+
llvm::Value *load =
1319+
builder.CreateLoad(moduleTranslation.convertType(privVarType),
1320+
moduleTranslation.lookupValue(blockArg));
1321+
return load;
1322+
}
1323+
}
1324+
return moduleTranslation.lookupValue(privateVar);
1325+
}
1326+
12891327
/// Allocate delayed private variables. Returns the basic block which comes
12901328
/// after all of these allocations. llvm::Value * for each of these private
12911329
/// variables are populated in llvmPrivateVars.
@@ -1296,7 +1334,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12961334
MutableArrayRef<omp::PrivateClauseOp> privateDecls,
12971335
MutableArrayRef<mlir::Value> mlirPrivateVars,
12981336
llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1299-
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
1337+
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1338+
omp::TargetOp targetOp = nullptr,
1339+
llvm::DenseMap<Value, int> *mappedPrivateVars = nullptr) {
13001340
llvm::IRBuilderBase::InsertPointGuard guard(builder);
13011341
// Allocate private vars
13021342
llvm::BranchInst *allocaTerminator =
@@ -1326,7 +1366,8 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13261366
Region &allocRegion = privDecl.getAllocRegion();
13271367

13281368
// map allocation region block argument
1329-
llvm::Value *nonPrivateVar = moduleTranslation.lookupValue(mlirPrivVar);
1369+
llvm::Value *nonPrivateVar = findAssociatedValue(
1370+
mlirPrivVar, builder, moduleTranslation, targetOp, mappedPrivateVars);
13301371
assert(nonPrivateVar);
13311372
moduleTranslation.mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);
13321373

@@ -1341,6 +1382,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13411382
} else {
13421383
builder.SetInsertPoint(privAllocBlock->getTerminator());
13431384
}
1385+
13441386
if (failed(inlineConvertOmpRegions(allocRegion, "omp.private.alloc",
13451387
builder, moduleTranslation, &phis)))
13461388
return llvm::createStringError(
@@ -3809,43 +3851,6 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38093851
return builder.saveIP();
38103852
}
38113853

3812-
/// Return the llvm::Value * corresponding to the `privateVar` that
3813-
/// is being privatized. It isn't always as simple as looking up
3814-
/// moduleTranslation with privateVar. For instance, in case of
3815-
/// an allocatable, the descriptor for the allocatable is privatized.
3816-
/// This descriptor is mapped using an MapInfoOp. So, this function
3817-
/// will return a pointer to the llvm::Value corresponding to the
3818-
/// block argument for the mapped descriptor.
3819-
static llvm::Value *
3820-
findHostAssociatedValue(Value privateVar, omp::TargetOp targetOp,
3821-
llvm::DenseMap<Value, int> &mappedPrivateVars,
3822-
llvm::IRBuilderBase &builder,
3823-
LLVM::ModuleTranslation &moduleTranslation) {
3824-
if (mappedPrivateVars.contains(privateVar)) {
3825-
int blockArgIndex = mappedPrivateVars[privateVar];
3826-
Value blockArg = targetOp.getRegion().getArgument(blockArgIndex);
3827-
Type privVarType = privateVar.getType();
3828-
Type blockArgType = blockArg.getType();
3829-
assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
3830-
"A block argument corresponding to a mapped var should have "
3831-
"!llvm.ptr type");
3832-
3833-
if (privVarType == blockArg.getType())
3834-
return moduleTranslation.lookupValue(blockArg);
3835-
3836-
if (!isa<LLVM::LLVMPointerType>(privVarType)) {
3837-
// This typically happens when the privatized type is lowered from
3838-
// boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
3839-
// struct/pair is passed by value. But, mapped values are passed only as
3840-
// pointers, so before we privatize, we must load the pointer.
3841-
llvm::Value *load =
3842-
builder.CreateLoad(moduleTranslation.convertType(privVarType),
3843-
moduleTranslation.lookupValue(blockArg));
3844-
return load;
3845-
}
3846-
}
3847-
return moduleTranslation.lookupValue(privateVar);
3848-
}
38493854
static LogicalResult
38503855
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38513856
LLVM::ModuleTranslation &moduleTranslation) {
@@ -3949,68 +3954,49 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39493954
attr.isStringAttribute())
39503955
llvmOutlinedFn->addFnAttr(attr);
39513956

3952-
builder.restoreIP(codeGenIP);
39533957
for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
39543958
auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
39553959
llvm::Value *mapOpValue =
39563960
moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
39573961
moduleTranslation.mapValue(arg, mapOpValue);
39583962
}
3963+
39593964
// Do privatization after moduleTranslation has already recorded
39603965
// mapped values.
3966+
MutableArrayRef<BlockArgument> privateBlockArgs =
3967+
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3968+
SmallVector<mlir::Value> mlirPrivateVars;
39613969
SmallVector<llvm::Value *> llvmPrivateVars;
3970+
SmallVector<omp::PrivateClauseOp> privateDecls;
3971+
mlirPrivateVars.reserve(privateBlockArgs.size());
3972+
llvmPrivateVars.reserve(privateBlockArgs.size());
3973+
collectPrivatizationDecls(targetOp, privateDecls);
3974+
for (mlir::Value privateVar : targetOp.getPrivateVars())
3975+
mlirPrivateVars.push_back(privateVar);
3976+
3977+
llvm::Expected<llvm::BasicBlock *> afterAllocas =
3978+
allocatePrivateVars(builder, moduleTranslation, privateBlockArgs,
3979+
privateDecls, mlirPrivateVars, llvmPrivateVars,
3980+
allocaIP, targetOp, &mappedPrivateVars);
3981+
3982+
if (handleError(afterAllocas, *targetOp).failed())
3983+
return llvm::make_error<PreviouslyReportedError>();
3984+
39623985
SmallVector<Region *> privateCleanupRegions;
3963-
if (!targetOp.getPrivateVars().empty()) {
3964-
builder.restoreIP(allocaIP);
3965-
3966-
OperandRange privateVars = targetOp.getPrivateVars();
3967-
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
3968-
MutableArrayRef<BlockArgument> privateBlockArgs =
3969-
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3970-
3971-
for (auto [privVar, privatizerNameAttr, privBlockArg] :
3972-
llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) {
3973-
3974-
SymbolRefAttr privSym = cast<SymbolRefAttr>(privatizerNameAttr);
3975-
omp::PrivateClauseOp privatizer = findPrivatizer(&opInst, privSym);
3976-
assert(privatizer.getDataSharingType() !=
3977-
omp::DataSharingClauseType::FirstPrivate &&
3978-
"unsupported privatizer");
3979-
Region &allocRegion = privatizer.getAllocRegion();
3980-
BlockArgument allocRegionArg = allocRegion.getArgument(0);
3981-
moduleTranslation.mapValue(
3982-
allocRegionArg,
3983-
findHostAssociatedValue(privVar, targetOp, mappedPrivateVars,
3984-
builder, moduleTranslation));
3985-
SmallVector<llvm::Value *, 1> yieldedValues;
3986-
if (failed(inlineConvertOmpRegions(
3987-
allocRegion, "omp.targetop.privatizer", builder,
3988-
moduleTranslation, &yieldedValues))) {
3989-
return llvm::createStringError(
3990-
"failed to inline `alloc` region of `omp.private`");
3991-
}
3992-
assert(yieldedValues.size() == 1);
3993-
llvm::Value *llvmReplacementValue = yieldedValues.front();
3994-
moduleTranslation.mapValue(privBlockArg, llvmReplacementValue);
3995-
if (!privatizer.getDeallocRegion().empty()) {
3996-
llvmPrivateVars.push_back(llvmReplacementValue);
3997-
privateCleanupRegions.push_back(&privatizer.getDeallocRegion());
3998-
}
3999-
moduleTranslation.forgetMapping(allocRegion);
4000-
builder.restoreIP(builder.saveIP());
4001-
}
4002-
}
3986+
llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
3987+
[](omp::PrivateClauseOp privatizer) {
3988+
return &privatizer.getDeallocRegion();
3989+
});
40033990

3991+
builder.restoreIP(codeGenIP);
40043992
llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
40053993
targetRegion, "omp.target", builder, moduleTranslation);
3994+
40063995
if (!exitBlock)
40073996
return exitBlock.takeError();
40083997

40093998
builder.SetInsertPoint(*exitBlock);
4010-
if (!llvmPrivateVars.empty()) {
4011-
assert(llvmPrivateVars.size() == privateCleanupRegions.size() &&
4012-
"Number of private variables needing cleanup not equal to number"
4013-
"of privatizers with dealloc regions");
3999+
if (!privateCleanupRegions.empty()) {
40144000
if (failed(inlineOmpRegionCleanup(
40154001
privateCleanupRegions, llvmPrivateVars, moduleTranslation,
40164002
builder, "omp.targetop.private.cleanup",
@@ -4020,10 +4006,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40204006
"op in the target region");
40214007
}
40224008
}
4023-
return builder.saveIP();
4009+
4010+
return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
40244011
};
40254012

4026-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
40274013
StringRef parentName = parentFn.getName();
40284014

40294015
llvm::TargetRegionEntryInfo entryInfo;
@@ -4034,9 +4020,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40344020
int32_t defaultValTeams = -1;
40354021
int32_t defaultValThreads = 0;
40364022

4037-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4038-
findAllocaInsertPoint(builder, moduleTranslation);
4039-
40404023
MapInfoData mapData;
40414024
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
40424025
builder);
@@ -4084,6 +4067,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40844067
buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
40854068
moduleTranslation, dds);
40864069

4070+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4071+
findAllocaInsertPoint(builder, moduleTranslation);
4072+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4073+
40874074
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
40884075
moduleTranslation.getOpenMPBuilder()->createTarget(
40894076
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,

mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ module attributes {omp.is_target_device = true} {
3333

3434
// CHECK: user_code.entry: ; preds = %entry
3535
// CHECK: %[[LOAD_BYREF:.*]] = load ptr, ptr %[[ALLOCA_BYREF]], align 8
36+
// CHECK: br label %outlined.body
37+
38+
// CHECK: outlined.body:
3639
// CHECK: br label %omp.target
3740

38-
// CHECK: omp.target: ; preds = %user_code.entry
41+
// CHECK: omp.target:
3942
// CHECK: %[[VAL_LOAD_BYCOPY:.*]] = load i32, ptr %[[ALLOCA_BYCOPY]], align 4
4043
// CHECK: store i32 %[[VAL_LOAD_BYCOPY]], ptr %[[LOAD_BYREF]], align 4
4144
// CHECK: br label %omp.region.cont

mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module attributes {omp.is_target_device = true} {
1717
llvm.func @_QQmain() attributes {} {
1818
%0 = llvm.mlir.addressof @_QMtest_0Esp : !llvm.ptr
1919

20-
// CHECK-DAG: omp.target: ; preds = %user_code.entry
20+
// CHECK-DAG: omp.target: ; preds = %outlined.body
2121
// CHECK-DAG: %[[V:.*]] = load ptr, ptr @_QMtest_0Esp_decl_tgt_ref_ptr, align 8
2222
// CHECK-DAG: store i32 1, ptr %[[V]], align 4
2323
// CHECK-DAG: br label %omp.region.cont

mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
// CHECK: user_code.entry: ; preds = %[[VAL_10:.*]]
1414
// CHECK-NEXT: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_3]], align 8
1515
// CHECK-NEXT: br label %[[VAL_12:.*]]
16-
// CHECK: omp.target: ; preds = %[[VAL_8]]
16+
17+
// CHECK: [[VAL_12]]:
18+
// CHECK-NEXT: br label %[[TARGET_REG_ENTRY:.*]]
19+
20+
// CHECK: [[TARGET_REG_ENTRY]]: ; preds = %[[VAL_12]]
1721
// CHECK-NEXT: %[[VAL_13:.*]] = load ptr, ptr %[[VAL_11]], align 8
1822
// CHECK-NEXT: store i32 999, ptr %[[VAL_13]], align 4
1923
// CHECK-NEXT: br label %[[VAL_14:.*]]

0 commit comments

Comments
 (0)