Skip to content

Commit 9062b45

Browse files
committed
Make convertOmpTarget use allocatePrivateVars.
Allows more code reuse by generalizing `allocatePrivateVars` to serve `target` op conversion.
1 parent 74bcf0a commit 9062b45

File tree

5 files changed

+91
-94
lines changed

5 files changed

+91
-94
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6807,8 +6807,11 @@ static Expected<Function *> createOutlinedFunction(
68076807
OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);
68086808

68096809
// Insert target deinit call in the device compilation pass.
6810-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
6811-
CBFunc(Builder.saveIP(), Builder.saveIP());
6810+
BasicBlock *OutlinedBodyBB =
6811+
splitBB(Builder, /*CreateBranch=*/true, "outlined.body");
6812+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = CBFunc(
6813+
Builder.saveIP(),
6814+
OpenMPIRBuilder::InsertPointTy(OutlinedBodyBB, OutlinedBodyBB->begin()));
68126815
if (!AfterIP)
68136816
return AfterIP.takeError();
68146817
Builder.restoreIP(*AfterIP);

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 76 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,6 +1242,44 @@ static LogicalResult allocAndInitializeReductionVars(
12421242
return success();
12431243
}
12441244

1245+
/// Return the llvm::Value * corresponding to the `privateVar` that
1246+
/// is being privatized. It isn't always as simple as looking up
1247+
/// moduleTranslation with privateVar. For instance, in case of
1248+
/// an allocatable, the descriptor for the allocatable is privatized.
1249+
/// This descriptor is mapped using a MapInfoOp. So, this function
1250+
/// will return a pointer to the llvm::Value corresponding to the
1251+
/// block argument for the mapped descriptor.
1252+
static llvm::Value *
1253+
findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1254+
LLVM::ModuleTranslation &moduleTranslation,
1255+
omp::TargetOp targetOp = nullptr,
1256+
llvm::DenseMap<Value, int> *mappedPrivateVars = nullptr) {
1257+
if (mappedPrivateVars != nullptr && mappedPrivateVars->contains(privateVar)) {
1258+
int blockArgIndex = (*mappedPrivateVars)[privateVar];
1259+
Value blockArg = targetOp.getRegion().getArgument(blockArgIndex);
1260+
Type privVarType = privateVar.getType();
1261+
Type blockArgType = blockArg.getType();
1262+
assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1263+
"A block argument corresponding to a mapped var should have "
1264+
"!llvm.ptr type");
1265+
1266+
if (privVarType == blockArg.getType())
1267+
return moduleTranslation.lookupValue(blockArg);
1268+
1269+
if (!isa<LLVM::LLVMPointerType>(privVarType)) {
1270+
// This typically happens when the privatized type is lowered from
1271+
// boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1272+
// struct/pair is passed by value. But, mapped values are passed only as
1273+
// pointers, so before we privatize, we must load the pointer.
1274+
llvm::Value *load =
1275+
builder.CreateLoad(moduleTranslation.convertType(privVarType),
1276+
moduleTranslation.lookupValue(blockArg));
1277+
return load;
1278+
}
1279+
}
1280+
return moduleTranslation.lookupValue(privateVar);
1281+
}
1282+
12451283
/// Allocate delayed private variables. Returns the basic block which comes
12461284
/// after all of these allocations. llvm::Value * for each of these private
12471285
/// variables are populated in llvmPrivateVars.
@@ -1252,7 +1290,9 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12521290
MutableArrayRef<omp::PrivateClauseOp> privateDecls,
12531291
MutableArrayRef<mlir::Value> mlirPrivateVars,
12541292
llvm::SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1255-
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
1293+
const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1294+
omp::TargetOp targetOp = nullptr,
1295+
llvm::DenseMap<Value, int> *mappedPrivateVars = nullptr) {
12561296
// Allocate private vars
12571297
llvm::BranchInst *allocaTerminator =
12581298
llvm::cast<llvm::BranchInst>(allocaIP.getBlock()->getTerminator());
@@ -1281,7 +1321,8 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12811321
Region &allocRegion = privDecl.getAllocRegion();
12821322

12831323
// map allocation region block argument
1284-
llvm::Value *nonPrivateVar = moduleTranslation.lookupValue(mlirPrivVar);
1324+
llvm::Value *nonPrivateVar = findAssociatedValue(
1325+
mlirPrivVar, builder, moduleTranslation, targetOp, mappedPrivateVars);
12851326
assert(nonPrivateVar);
12861327
moduleTranslation.mapValue(privDecl.getAllocMoldArg(), nonPrivateVar);
12871328

@@ -1296,6 +1337,7 @@ allocatePrivateVars(llvm::IRBuilderBase &builder,
12961337
} else {
12971338
builder.SetInsertPoint(privAllocBlock->getTerminator());
12981339
}
1340+
12991341
if (failed(inlineConvertOmpRegions(allocRegion, "omp.private.alloc",
13001342
builder, moduleTranslation, &phis)))
13011343
return llvm::createStringError(
@@ -3806,43 +3848,6 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
38063848
return builder.saveIP();
38073849
}
38083850

3809-
/// Return the llvm::Value * corresponding to the `privateVar` that
3810-
/// is being privatized. It isn't always as simple as looking up
3811-
/// moduleTranslation with privateVar. For instance, in case of
3812-
/// an allocatable, the descriptor for the allocatable is privatized.
3813-
/// This descriptor is mapped using a MapInfoOp. So, this function
3814-
/// will return a pointer to the llvm::Value corresponding to the
3815-
/// block argument for the mapped descriptor.
3816-
static llvm::Value *
3817-
findHostAssociatedValue(Value privateVar, omp::TargetOp targetOp,
3818-
llvm::DenseMap<Value, int> &mappedPrivateVars,
3819-
llvm::IRBuilderBase &builder,
3820-
LLVM::ModuleTranslation &moduleTranslation) {
3821-
if (mappedPrivateVars.contains(privateVar)) {
3822-
int blockArgIndex = mappedPrivateVars[privateVar];
3823-
Value blockArg = targetOp.getRegion().getArgument(blockArgIndex);
3824-
Type privVarType = privateVar.getType();
3825-
Type blockArgType = blockArg.getType();
3826-
assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
3827-
"A block argument corresponding to a mapped var should have "
3828-
"!llvm.ptr type");
3829-
3830-
if (privVarType == blockArg.getType())
3831-
return moduleTranslation.lookupValue(blockArg);
3832-
3833-
if (!isa<LLVM::LLVMPointerType>(privVarType)) {
3834-
// This typically happens when the privatized type is lowered from
3835-
// boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
3836-
// struct/pair is passed by value. But, mapped values are passed only as
3837-
// pointers, so before we privatize, we must load the pointer.
3838-
llvm::Value *load =
3839-
builder.CreateLoad(moduleTranslation.convertType(privVarType),
3840-
moduleTranslation.lookupValue(blockArg));
3841-
return load;
3842-
}
3843-
}
3844-
return moduleTranslation.lookupValue(privateVar);
3845-
}
38463851
static LogicalResult
38473852
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
38483853
LLVM::ModuleTranslation &moduleTranslation) {
@@ -3946,68 +3951,49 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
39463951
attr.isStringAttribute())
39473952
llvmOutlinedFn->addFnAttr(attr);
39483953

3949-
builder.restoreIP(codeGenIP);
39503954
for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
39513955
auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
39523956
llvm::Value *mapOpValue =
39533957
moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
39543958
moduleTranslation.mapValue(arg, mapOpValue);
39553959
}
3960+
39563961
// Do privatization after moduleTranslation has already recorded
39573962
// mapped values.
3963+
MutableArrayRef<BlockArgument> privateBlockArgs =
3964+
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3965+
SmallVector<mlir::Value> mlirPrivateVars;
39583966
SmallVector<llvm::Value *> llvmPrivateVars;
3967+
SmallVector<omp::PrivateClauseOp> privateDecls;
3968+
mlirPrivateVars.reserve(privateBlockArgs.size());
3969+
llvmPrivateVars.reserve(privateBlockArgs.size());
3970+
collectPrivatizationDecls(targetOp, privateDecls);
3971+
for (mlir::Value privateVar : targetOp.getPrivateVars())
3972+
mlirPrivateVars.push_back(privateVar);
3973+
3974+
llvm::Expected<llvm::BasicBlock *> afterAllocas =
3975+
allocatePrivateVars(builder, moduleTranslation, privateBlockArgs,
3976+
privateDecls, mlirPrivateVars, llvmPrivateVars,
3977+
allocaIP, targetOp, &mappedPrivateVars);
3978+
3979+
if (handleError(afterAllocas, *targetOp).failed())
3980+
return llvm::make_error<PreviouslyReportedError>();
3981+
39593982
SmallVector<Region *> privateCleanupRegions;
3960-
if (!targetOp.getPrivateVars().empty()) {
3961-
builder.restoreIP(allocaIP);
3962-
3963-
OperandRange privateVars = targetOp.getPrivateVars();
3964-
std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
3965-
MutableArrayRef<BlockArgument> privateBlockArgs =
3966-
cast<omp::BlockArgOpenMPOpInterface>(opInst).getPrivateBlockArgs();
3967-
3968-
for (auto [privVar, privatizerNameAttr, privBlockArg] :
3969-
llvm::zip_equal(privateVars, *privateSyms, privateBlockArgs)) {
3970-
3971-
SymbolRefAttr privSym = cast<SymbolRefAttr>(privatizerNameAttr);
3972-
omp::PrivateClauseOp privatizer = findPrivatizer(&opInst, privSym);
3973-
assert(privatizer.getDataSharingType() !=
3974-
omp::DataSharingClauseType::FirstPrivate &&
3975-
"unsupported privatizer");
3976-
Region &allocRegion = privatizer.getAllocRegion();
3977-
BlockArgument allocRegionArg = allocRegion.getArgument(0);
3978-
moduleTranslation.mapValue(
3979-
allocRegionArg,
3980-
findHostAssociatedValue(privVar, targetOp, mappedPrivateVars,
3981-
builder, moduleTranslation));
3982-
SmallVector<llvm::Value *, 1> yieldedValues;
3983-
if (failed(inlineConvertOmpRegions(
3984-
allocRegion, "omp.targetop.privatizer", builder,
3985-
moduleTranslation, &yieldedValues))) {
3986-
return llvm::createStringError(
3987-
"failed to inline `alloc` region of `omp.private`");
3988-
}
3989-
assert(yieldedValues.size() == 1);
3990-
llvm::Value *llvmReplacementValue = yieldedValues.front();
3991-
moduleTranslation.mapValue(privBlockArg, llvmReplacementValue);
3992-
if (!privatizer.getDeallocRegion().empty()) {
3993-
llvmPrivateVars.push_back(llvmReplacementValue);
3994-
privateCleanupRegions.push_back(&privatizer.getDeallocRegion());
3995-
}
3996-
moduleTranslation.forgetMapping(allocRegion);
3997-
builder.restoreIP(builder.saveIP());
3998-
}
3999-
}
3983+
llvm::transform(privateDecls, std::back_inserter(privateCleanupRegions),
3984+
[](omp::PrivateClauseOp privatizer) {
3985+
return &privatizer.getDeallocRegion();
3986+
});
40003987

3988+
builder.restoreIP(codeGenIP);
40013989
llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
40023990
targetRegion, "omp.target", builder, moduleTranslation);
3991+
40033992
if (!exitBlock)
40043993
return exitBlock.takeError();
40053994

40063995
builder.SetInsertPoint(*exitBlock);
4007-
if (!llvmPrivateVars.empty()) {
4008-
assert(llvmPrivateVars.size() == privateCleanupRegions.size() &&
4009-
"Number of private variables needing cleanup not equal to number"
4010-
"of privatizers with dealloc regions");
3996+
if (!privateCleanupRegions.empty()) {
40113997
if (failed(inlineOmpRegionCleanup(
40123998
privateCleanupRegions, llvmPrivateVars, moduleTranslation,
40133999
builder, "omp.targetop.private.cleanup",
@@ -4017,10 +4003,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40174003
"op in the target region");
40184004
}
40194005
}
4020-
return builder.saveIP();
4006+
4007+
return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
40214008
};
40224009

4023-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
40244010
StringRef parentName = parentFn.getName();
40254011

40264012
llvm::TargetRegionEntryInfo entryInfo;
@@ -4031,9 +4017,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40314017
int32_t defaultValTeams = -1;
40324018
int32_t defaultValThreads = 0;
40334019

4034-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4035-
findAllocaInsertPoint(builder, moduleTranslation);
4036-
40374020
MapInfoData mapData;
40384021
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
40394022
builder);
@@ -4081,6 +4064,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40814064
buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
40824065
moduleTranslation, dds);
40834066

4067+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4068+
findAllocaInsertPoint(builder, moduleTranslation);
4069+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4070+
40844071
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
40854072
moduleTranslation.getOpenMPBuilder()->createTarget(
40864073
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,

mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,12 @@ module attributes {omp.is_target_device = true} {
3333

3434
// CHECK: user_code.entry: ; preds = %entry
3535
// CHECK: %[[LOAD_BYREF:.*]] = load ptr, ptr %[[ALLOCA_BYREF]], align 8
36+
// CHECK: br label %outlined.body
37+
38+
// CHECK: outlined.body:
3639
// CHECK: br label %omp.target
3740

38-
// CHECK: omp.target: ; preds = %user_code.entry
41+
// CHECK: omp.target:
3942
// CHECK: %[[VAL_LOAD_BYCOPY:.*]] = load i32, ptr %[[ALLOCA_BYCOPY]], align 4
4043
// CHECK: store i32 %[[VAL_LOAD_BYCOPY]], ptr %[[LOAD_BYREF]], align 4
4144
// CHECK: br label %omp.region.cont

mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ module attributes {omp.is_target_device = true} {
1717
llvm.func @_QQmain() attributes {} {
1818
%0 = llvm.mlir.addressof @_QMtest_0Esp : !llvm.ptr
1919

20-
// CHECK-DAG: omp.target: ; preds = %user_code.entry
20+
// CHECK-DAG: omp.target: ; preds = %outlined.body
2121
// CHECK-DAG: %[[V:.*]] = load ptr, ptr @_QMtest_0Esp_decl_tgt_ref_ptr, align 8
2222
// CHECK-DAG: store i32 1, ptr %[[V]], align 4
2323
// CHECK-DAG: br label %omp.region.cont

mlir/test/Target/LLVMIR/openmp-target-use-device-nested.mlir

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
// CHECK: user_code.entry: ; preds = %[[VAL_10:.*]]
1414
// CHECK-NEXT: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_3]], align 8
1515
// CHECK-NEXT: br label %[[VAL_12:.*]]
16-
// CHECK: omp.target: ; preds = %[[VAL_8]]
16+
17+
// CHECK: [[VAL_12]]:
18+
// CHECK-NEXT: br label %[[TARGET_REG_ENTRY:.*]]
19+
20+
// CHECK: [[TARGET_REG_ENTRY]]: ; preds = %[[VAL_12]]
1721
// CHECK-NEXT: %[[VAL_13:.*]] = load ptr, ptr %[[VAL_11]], align 8
1822
// CHECK-NEXT: store i32 999, ptr %[[VAL_13]], align 4
1923
// CHECK-NEXT: br label %[[VAL_14:.*]]

0 commit comments

Comments
 (0)