Skip to content

Commit 905d343

Browse files
committed
[mlir][OpenMP] - MLIR to LLVMIR translation support for delayed privatization of allocatables in omp.target ops
This PR adds support to translate the `private` clause from MLIR to LLVMIR when used on allocatables in the context of an `omp.target` op.
1 parent 7c12418 commit 905d343

File tree

12 files changed

+418
-78
lines changed

12 files changed

+418
-78
lines changed

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,6 @@ class MapsForPrivatizedSymbolsPass
4949
: public flangomp::impl::MapsForPrivatizedSymbolsPassBase<
5050
MapsForPrivatizedSymbolsPass> {
5151

52-
bool privatizerNeedsMap(omp::PrivateClauseOp &privatizer) {
53-
Region &allocRegion = privatizer.getAllocRegion();
54-
Value blockArg0 = allocRegion.getArgument(0);
55-
if (blockArg0.use_empty())
56-
return false;
57-
return true;
58-
}
5952
omp::MapInfoOp createMapInfo(Location loc, Value var,
6053
fir::FirOpBuilder &builder) {
6154
uint64_t mapTypeTo = static_cast<
@@ -134,7 +127,7 @@ class MapsForPrivatizedSymbolsPass
134127
omp::PrivateClauseOp privatizer =
135128
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
136129
targetOp, privatizerName);
137-
if (!privatizerNeedsMap(privatizer)) {
130+
if (!privatizer.needsMap()) {
138131
privVarMapIdx.push_back(-1);
139132
continue;
140133
}

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6821,8 +6821,11 @@ static Expected<Function *> createOutlinedFunction(
68216821
OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
68226822

68236823
// Insert target deinit call in the device compilation pass.
6824-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6825-
CBFunc(Builder.saveIP(), Builder.saveIP());
6824+
BasicBlock *OutlinedBodyBB =
6825+
splitBB(Builder, /*CreateBranch=*/true, "outlined.body");
6826+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
6827+
Builder.saveIP(),
6828+
OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
68266829
if (!AfterIP)
68276830
return AfterIP.takeError();
68286831
Builder.restoreIP(*AfterIP);

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6358,7 +6358,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
63586358
auto *Load2 = Load1->getNextNode();
63596359
EXPECT_TRUE(isa<LoadInst>(Load2));
63606360

6361-
auto *Value1 = Load2->getNextNode();
6361+
auto *OutlinedBlockBr = Load2->getNextNode();
6362+
EXPECT_TRUE(isa<BranchInst>(OutlinedBlockBr));
6363+
6364+
auto *OutlinedBlock = OutlinedBlockBr->getSuccessor(0);
6365+
EXPECT_EQ(OutlinedBlock->getName(), "outlined.body");
6366+
6367+
auto *Value1 = OutlinedBlock->getFirstNonPHI();
63626368
EXPECT_EQ(Value1, Value);
63636369
EXPECT_EQ(Value1->getNextNode(), TargetStore);
63646370
auto *Deinit = TargetStore->getNextNode();
@@ -6510,7 +6516,14 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
65106516
EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
65116517
auto *Load1 = UserCodeBlock->getFirstNonPHI();
65126518
EXPECT_TRUE(isa<LoadInst>(Load1));
6513-
auto *Load2 = Load1->getNextNode();
6519+
6520+
auto *OutlinedBlockBr = Load1->getNextNode();
6521+
EXPECT_TRUE(isa<BranchInst>(OutlinedBlockBr));
6522+
6523+
auto *OutlinedBlock = OutlinedBlockBr->getSuccessor(0);
6524+
EXPECT_EQ(OutlinedBlock->getName(), "outlined.body");
6525+
6526+
auto *Load2 = OutlinedBlock->getFirstNonPHI();
65146527
EXPECT_TRUE(isa<LoadInst>(Load2));
65156528
EXPECT_EQ(Load2, Value);
65166529
EXPECT_EQ(Load2->getNextNode(), TargetStore);

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>
135135
auto &region = getDeallocRegion();
136136
return region.empty() ? nullptr : region.getArgument(0);
137137
}
138+
139+
/// needsMap returns true if the value being privatized should additionally
140+
/// be mapped to the target region using a MapInfoOp. This is most common
141+
/// when an allocatable is privatized. In such cases, the descriptor is used
142+
/// in privatization and needs to be mapped on to the device.
143+
bool needsMap() {
144+
return !getAllocMoldArg().use_empty();
145+
}
138146
}];
139147

140148
let hasRegionVerifier = 1;

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 146 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
299299
if (privatizer.getDataSharingType() ==
300300
omp::DataSharingClauseType::FirstPrivate)
301301
result = todo("firstprivate");
302-
303-
if (!privatizer.getDeallocRegion().empty())
304-
result = op.emitError("not yet implemented: privatization of "
305-
"structures in omp.target operation");
306302
}
307303
}
308304
checkThreadLimit(op, result);
@@ -1290,6 +1286,43 @@ static LogicalResult allocAndInitializeReductionVars(
12901286
isByRef, deferredStores);
12911287
}
12921288

1289+
/// Return the llvm::Value * corresponding to the `privateVar` that
1290+
/// is being privatized. It isn't always as simple as looking up
1291+
/// moduleTranslation with privateVar. For instance, in case of
1292+
/// an allocatable, the descriptor for the allocatable is privatized.
1293+
/// This descriptor is mapped using an MapInfoOp. So, this function
1294+
/// will return a pointer to the llvm::Value corresponding to the
1295+
/// block argument for the mapped descriptor.
1296+
static llvm::Value *
1297+
findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1298+
LLVM::ModuleTranslation &moduleTranslation,
1299+
omp::TargetOp targetOp = nullptr,
1300+
llvm::DenseMap<Value, int> *mappedPrivateVars = nullptr) {
1301+
if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(privateVar))
1302+
return moduleTranslation.lookupValue(privateVar);
1303+
1304+
int blockArgIndex = (*mappedPrivateVars)[privateVar];
1305+
Value blockArg = targetOp.getRegion().getArgument(blockArgIndex);
1306+
Type privVarType = privateVar.getType();
1307+
Type blockArgType = blockArg.getType();
1308+
assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1309+
"A block argument corresponding to a mapped var should have "
1310+
"!llvm.ptr type");
1311+
1312+
if (privVarType == blockArgType)
1313+
return moduleTranslation.lookupValue(blockArg);
1314+
1315+
// This typically happens when the privatized type is lowered from
1316+
// boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1317+
// struct/pair is passed by value. But, mapped values are passed only as
1318+
// pointers, so before we privatize, we must load the pointer.
1319+
if (!isa<LLVM::LLVMPointerType>(privVarType))
1320+
return builder.CreateLoad(moduleTranslation.convertType(privVarType),
1321+
moduleTranslation.lookupValue(blockArg));
1322+
1323+
return moduleTranslation.lookupValue(privateVar);
1324+
}
1325+
12931326
/// Allocate delayed private variables. Returns the basic block which comes
12941327
/// after all of these allocations. llvm::Value * for each of these private
12951328
/// variables are populated in llvmPrivateVars.
@@ -1300,7 +1333,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13001333
MutableArrayRef<omp::PrivateClauseOp> privateDecls,
13011334
MutableArrayRef<mlir::Value> mlirPrivateVars,
13021335
llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1303-
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
1336+
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1337+
omp::TargetOp targetOp = nullptr,
1338+
llvm::DenseMap<Value, int> *mappedPrivateVars = nullptr) {
13041339
llvm::IRBuilderBase::InsertPointGuard guard(builder);
13051340
// Allocate private vars
13061341
llvm::BranchInst *allocaTerminator =
@@ -1330,7 +1365,8 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13301365
Region &allocRegion = privDecl.getAllocRegion();
13311366

13321367
// map allocation region block argument
1333-
llvm::Value *nonPrivateVar = moduleTranslation.lookupValue(mlirPrivVar);
1368+
llvm::Value *nonPrivateVar = findAssociatedValue(
1369+
mlirPrivVar, builder, moduleTranslation, targetOp, mappedPrivateVars);
13341370
assert(nonPrivateVar);
13351371
moduleTranslation.mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);
13361372

@@ -1345,6 +1381,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
13451381
} else {
13461382
builder.SetInsertPoint(privAllocBlock->getTerminator());
13471383
}
1384+
13481385
if (failed(inlineConvertOmpRegions(allocRegion, "omp.private.alloc",
13491386
builder, moduleTranslation, &phis)))
13501387
return llvm::createStringError(
@@ -3829,6 +3866,17 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38293866
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
38303867
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
38313868
auto &targetRegion = targetOp.getRegion();
3869+
// Holds the private vars that have been mapped along with the block argument
3870+
// that corresponds to the MapInfoOp corresponding to the private var in
3871+
// question. So, for instance:
3872+
//
3873+
// %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
3874+
// omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
3875+
//
3876+
// Then, %10 has been created so that the descriptor can be used by the
3877+
// privatizer @box.privatizer on the device side. Here we'd record {%6#0, 0}
3878+
// in the mappedPrivateVars map.
3879+
llvm::DenseMap<Value, int> mappedPrivateVars;
38323880
DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
38333881
SmallVector<Value> mapVars = targetOp.getMapVars();
38343882
ArrayRef<BlockArgument> mapBlockArgs =
@@ -3840,6 +3888,56 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38403888
bool isOffloadEntry =
38413889
isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
38423890

3891+
// For some private variables, the MapsForPrivatizedVariablesPass
3892+
// creates MapInfoOp instances. Go through the private variables and
3893+
// the mapped variables so that during codegeneration we are able
3894+
// to quickly look up the corresponding map variable, if any for each
3895+
// private variable.
3896+
if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
3897+
auto argIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
3898+
OperandRange privateVars = targetOp.getPrivateVars();
3899+
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
3900+
std::optional<DenseI64ArrayAttr> privateMapIndices =
3901+
targetOp.getPrivateMapsAttr();
3902+
3903+
for (auto [privVarIdx, privVarSymPair] :
3904+
llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
3905+
auto privVar = std::get<0>(privVarSymPair);
3906+
auto privSym = std::get<1>(privVarSymPair);
3907+
3908+
SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
3909+
omp::PrivateClauseOp privatizer =
3910+
findPrivatizer(targetOp, privatizerName);
3911+
3912+
if (!privatizer.needsMap())
3913+
continue;
3914+
3915+
mlir::Value mappedValue =
3916+
targetOp.getMappedValueForPrivateVar(privVarIdx);
3917+
assert(mappedValue && "Expected to find mapped value for a privatized "
3918+
"variable that needs mapping");
3919+
3920+
// The MapInfoOp defining the map var isn't really needed later.
3921+
// So, we don't store it in any datastructure. Instead, we just
3922+
// do some sanity checks on it right now.
3923+
auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
3924+
Type varType = mapInfoOp.getVarType();
3925+
3926+
// Check #1: Check that the type of the private variable matches
3927+
// the type of the variable being mapped.
3928+
if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
3929+
assert(
3930+
varType == privVar.getType() &&
3931+
"Type of private var doesn't match the type of the mapped value");
3932+
3933+
// Ok, only 1 sanity check for now.
3934+
// Record the index of the block argument corresponding to this
3935+
// mapvar.
3936+
mappedPrivateVars.insert({privVar, argIface.getMapBlockArgsStart() +
3937+
(*privateMapIndices)[privVarIdx]});
3938+
}
3939+
}
3940+
38433941
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
38443942
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
38453943
-> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
@@ -3859,7 +3957,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38593957
attr.isStringAttribute())
38603958
llvmOutlinedFn->addFnAttr(attr);
38613959

3862-
builder.restoreIP(codeGenIP);
38633960
for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
38643961
auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
38653962
llvm::Value *mapOpValue =
@@ -3869,50 +3966,53 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38693966

38703967
// Do privatization after moduleTranslation has already recorded
38713968
// mapped values.
3872-
if (!targetOp.getPrivateVars().empty()) {
3873-
builder.restoreIP(allocaIP);
3874-
3875-
OperandRange privateVars = targetOp.getPrivateVars();
3876-
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
3877-
MutableArrayRef<BlockArgument> privateBlockArgs =
3878-
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3879-
3880-
for (auto [privVar, privatizerNameAttr, privBlockArg] :
3881-
llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) {
3882-
3883-
SymbolRefAttr privSym = cast<SymbolRefAttr>(privatizerNameAttr);
3884-
omp::PrivateClauseOp privatizer = findPrivatizer(&opInst, privSym);
3885-
assert(privatizer.getDataSharingType() !=
3886-
omp::DataSharingClauseType::FirstPrivate &&
3887-
privatizer.getDeallocRegion().empty() &&
3888-
"unsupported privatizer");
3889-
moduleTranslation.mapValue(privatizer.getAllocMoldArg(),
3890-
moduleTranslation.lookupValue(privVar));
3891-
Region &allocRegion = privatizer.getAllocRegion();
3892-
SmallVector<llvm::Value *, 1> yieldedValues;
3893-
if (failed(inlineConvertOmpRegions(
3894-
allocRegion, "omp.targetop.privatizer", builder,
3895-
moduleTranslation, &yieldedValues))) {
3896-
return llvm::createStringError(
3897-
"failed to inline `alloc` region of `omp.private`");
3898-
}
3899-
assert(yieldedValues.size() == 1);
3900-
moduleTranslation.mapValue(privBlockArg, yieldedValues.front());
3901-
moduleTranslation.forgetMapping(allocRegion);
3902-
builder.restoreIP(builder.saveIP());
3903-
}
3904-
}
3969+
MutableArrayRef<BlockArgument> privateBlockArgs =
3970+
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3971+
SmallVector<mlir::Value> mlirPrivateVars;
3972+
SmallVector<llvm::Value *> llvmPrivateVars;
3973+
SmallVector<omp::PrivateClauseOp> privateDecls;
3974+
mlirPrivateVars.reserve(privateBlockArgs.size());
3975+
llvmPrivateVars.reserve(privateBlockArgs.size());
3976+
collectPrivatizationDecls(targetOp, privateDecls);
3977+
for (mlir::Value privateVar : targetOp.getPrivateVars())
3978+
mlirPrivateVars.push_back(privateVar);
3979+
3980+
llvm::Expected<llvm::BasicBlock *> afterAllocas =
3981+
allocatePrivateVars(builder, moduleTranslation, privateBlockArgs,
3982+
privateDecls, mlirPrivateVars, llvmPrivateVars,
3983+
allocaIP, targetOp, &mappedPrivateVars);
3984+
3985+
if (handleError(afterAllocas, *targetOp).failed())
3986+
return llvm::make_error<PreviouslyReportedError>();
39053987

3988+
SmallVector<Region *> privateCleanupRegions;
3989+
llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
3990+
[](omp::PrivateClauseOp privatizer) {
3991+
return &privatizer.getDeallocRegion();
3992+
});
3993+
3994+
builder.restoreIP(codeGenIP);
39063995
llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
39073996
targetRegion, "omp.target", builder, moduleTranslation);
3997+
39083998
if (!exitBlock)
39093999
return exitBlock.takeError();
39104000

39114001
builder.SetInsertPoint(*exitBlock);
3912-
return builder.saveIP();
4002+
if (!privateCleanupRegions.empty()) {
4003+
if (failed(inlineOmpRegionCleanup(
4004+
privateCleanupRegions, llvmPrivateVars, moduleTranslation,
4005+
builder, "omp.targetop.private.cleanup",
4006+
/*shouldLoadCleanupRegionArg=*/false))) {
4007+
return llvm::createStringError(
4008+
"failed to inline `dealloc` region of `omp.private` "
4009+
"op in the target region");
4010+
}
4011+
}
4012+
4013+
return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
39134014
};
39144015

3915-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
39164016
StringRef parentName = parentFn.getName();
39174017

39184018
llvm::TargetRegionEntryInfo entryInfo;
@@ -3923,9 +4023,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39234023
int32_t defaultValTeams = -1;
39244024
int32_t defaultValThreads = 0;
39254025

3926-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3927-
findAllocaInsertPoint(builder, moduleTranslation);
3928-
39294026
MapInfoData mapData;
39304027
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
39314028
builder);
@@ -3973,6 +4070,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39734070
buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
39744071
moduleTranslation, dds);
39754072

4073+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4074+
findAllocaInsertPoint(builder, moduleTranslation);
4075+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4076+
39764077
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
39774078
moduleTranslation.getOpenMPBuilder()->createTarget(
39784079
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,

mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ module attributes {omp.is_target_device = true} {
3333

3434
// CHECK: user_code.entry: ; preds = %entry
3535
// CHECK: %[[LOAD_BYREF:.*]] = load ptr, ptr %[[ALLOCA_BYREF]], align 8
36+
// CHECK: br label %outlined.body
37+
38+
// CHECK: outlined.body:
3639
// CHECK: br label %omp.target
3740

38-
// CHECK: omp.target: ; preds = %user_code.entry
41+
// CHECK: omp.target:
3942
// CHECK: %[[VAL_LOAD_BYCOPY:.*]] = load i32, ptr %[[ALLOCA_BYCOPY]], align 4
4043
// CHECK: store i32 %[[VAL_LOAD_BYCOPY]], ptr %[[LOAD_BYREF]], align 4
4144
// CHECK: br label %omp.region.cont

mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module attributes {omp.is_target_device = true} {
1717
llvm.func @_QQmain() attributes {} {
1818
%0 = llvm.mlir.addressof @_QMtest_0Esp : !llvm.ptr
1919

20-
// CHECK-DAG: omp.target: ; preds = %user_code.entry
20+
// CHECK-DAG: omp.target: ; preds = %outlined.body
2121
// CHECK-DAG: %[[V:.*]] = load ptr, ptr @_QMtest_0Esp_decl_tgt_ref_ptr, align 8
2222
// CHECK-DAG: store i32 1, ptr %[[V]], align 4
2323
// CHECK-DAG: br label %omp.region.cont

0 commit comments

Comments
 (0)