Skip to content

Commit cdf4af1

Browse files
committed
[OpenMP][Flang] Emit default declare mappers implicitly for derived types
This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types. This supports nested derived types as well.
1 parent 63d6e3e commit cdf4af1

File tree

8 files changed

+325
-46
lines changed

8 files changed

+325
-46
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 99 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "flang/Lower/OpenMP/Clauses.h"
1818
#include "flang/Lower/PFTBuilder.h"
1919
#include "flang/Lower/Support/ReductionProcessor.h"
20+
#include "flang/Optimizer/Dialect/FIRType.h"
2021
#include "flang/Parser/tools.h"
2122
#include "flang/Semantics/tools.h"
2223
#include "flang/Utils/OpenMP.h"
@@ -1223,26 +1224,77 @@ void ClauseProcessor::processMapObjects(
12231224
llvm::StringRef mapperIdNameRef) const {
12241225
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
12251226

1226-
auto getDefaultMapperID = [&](const omp::Object &object,
1227-
std::string &mapperIdName) {
1228-
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
1229-
firOpBuilder.getRegion().getParentOp())) {
1230-
const semantics::DerivedTypeSpec *typeSpec = nullptr;
1227+
auto getSymbolDerivedType = [](const semantics::Symbol &symbol)
1228+
-> const semantics::DerivedTypeSpec * {
1229+
const semantics::Symbol &ultimate = symbol.GetUltimate();
1230+
if (const semantics::DeclTypeSpec *declType = ultimate.GetType())
1231+
if (const auto *derived = declType->AsDerived())
1232+
return derived;
1233+
return nullptr;
1234+
};
1235+
1236+
auto addImplicitmapper = [&](const omp::Object &object,
1237+
std::string &mapperIdName,
1238+
bool allowGenerate) -> mlir::FlatSymbolRefAttr {
1239+
if (mapperIdName.empty())
1240+
return mlir::FlatSymbolRefAttr();
1241+
1242+
bool symbolExists = converter.getModuleOp().lookupSymbol(mapperIdName);
1243+
if (!symbolExists && !allowGenerate)
1244+
return mlir::FlatSymbolRefAttr();
12311245

1232-
if (object.sym()->owner().IsDerivedType())
1246+
auto getOrCreateMapperAttr = [&]() -> mlir::FlatSymbolRefAttr {
1247+
if (symbolExists)
1248+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
1249+
mapperIdName);
1250+
1251+
const semantics::DerivedTypeSpec *typeSpec =
1252+
getSymbolDerivedType(*object.sym());
1253+
if (!typeSpec && object.sym()->owner().IsDerivedType())
12331254
typeSpec = object.sym()->owner().derivedTypeSpec();
1234-
else if (object.sym()->GetType() &&
1235-
object.sym()->GetType()->category() ==
1236-
semantics::DeclTypeSpec::TypeDerived)
1237-
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
12381255

1239-
if (typeSpec) {
1256+
if (!typeSpec)
1257+
return mlir::FlatSymbolRefAttr();
1258+
1259+
mlir::Type type = converter.genType(*typeSpec);
1260+
auto recordType = mlir::dyn_cast<fir::RecordType>(type);
1261+
if (!recordType)
1262+
return mlir::FlatSymbolRefAttr();
1263+
1264+
return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
1265+
recordType, mapperIdName);
1266+
};
1267+
1268+
mlir::FlatSymbolRefAttr mapperAttr = getOrCreateMapperAttr();
1269+
if (!mapperAttr)
1270+
return mlir::FlatSymbolRefAttr();
1271+
1272+
return mapperAttr;
1273+
};
1274+
1275+
auto getDefaultMapperID = [&](const semantics::DerivedTypeSpec *typeSpec,
1276+
std::string &mapperIdName) {
1277+
mapperIdName.clear();
1278+
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
1279+
firOpBuilder.getRegion().getParentOp()) &&
1280+
typeSpec) {
1281+
mapperIdName =
1282+
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
1283+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) {
1284+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1285+
} else {
12401286
mapperIdName =
1241-
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
1242-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
1243-
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1287+
converter.mangleName(mapperIdName, *typeSpec->GetScope());
12441288
}
12451289
}
1290+
1291+
// Make sure we don't return a mapper to self.
1292+
llvm::StringRef parentOpName;
1293+
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
1294+
firOpBuilder.getRegion().getParentOp()))
1295+
parentOpName = declMapOp.getSymName();
1296+
if (mapperIdName == parentOpName)
1297+
mapperIdName.clear();
12461298
};
12471299

12481300
// Create the mapper symbol from its name, if specified.
@@ -1251,8 +1303,13 @@ void ClauseProcessor::processMapObjects(
12511303
mapperIdNameRef != "__implicit_mapper") {
12521304
std::string mapperIdName = mapperIdNameRef.str();
12531305
const omp::Object &object = objects.front();
1254-
if (mapperIdNameRef == "default")
1255-
getDefaultMapperID(object, mapperIdName);
1306+
if (mapperIdNameRef == "default") {
1307+
const semantics::DerivedTypeSpec *typeSpec =
1308+
getSymbolDerivedType(*object.sym());
1309+
if (!typeSpec && object.sym()->owner().IsDerivedType())
1310+
typeSpec = object.sym()->owner().derivedTypeSpec();
1311+
getDefaultMapperID(typeSpec, mapperIdName);
1312+
}
12561313
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
12571314
"mapper not found");
12581315
mapperId =
@@ -1290,13 +1347,33 @@ void ClauseProcessor::processMapObjects(
12901347
}
12911348
}
12921349

1350+
const semantics::DerivedTypeSpec *objectTypeSpec =
1351+
getSymbolDerivedType(*object.sym());
1352+
if (!objectTypeSpec && object.sym()->owner().IsDerivedType())
1353+
objectTypeSpec = object.sym()->owner().derivedTypeSpec();
1354+
12931355
if (mapperIdNameRef == "__implicit_mapper") {
1294-
std::string mapperIdName;
1295-
getDefaultMapperID(object, mapperIdName);
1296-
mapperId = converter.getModuleOp().lookupSymbol(mapperIdName)
1297-
? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
1298-
mapperIdName)
1299-
: mlir::FlatSymbolRefAttr();
1356+
if (parentObj.has_value()) {
1357+
mapperId = mlir::FlatSymbolRefAttr();
1358+
} else {
1359+
std::string mapperIdName;
1360+
getDefaultMapperID(objectTypeSpec, mapperIdName);
1361+
bool needsDefaultMapper =
1362+
(objectTypeSpec &&
1363+
requiresImplicitDefaultDeclareMapper(*objectTypeSpec)) ||
1364+
semantics::IsAllocatableOrObjectPointer(object.sym());
1365+
bool containsDelete = (mapTypeBits & mlir::omp::ClauseMapFlags::del) !=
1366+
mlir::omp::ClauseMapFlags::none;
1367+
bool mapperExists = !mapperIdName.empty() &&
1368+
converter.getModuleOp().lookupSymbol(mapperIdName);
1369+
if ((needsDefaultMapper || mapperExists) && !mapperIdName.empty() &&
1370+
!containsDelete)
1371+
mapperId = addImplicitmapper(object, mapperIdName,
1372+
/*allowGenerate=*/needsDefaultMapper &&
1373+
!mapperExists);
1374+
else
1375+
mapperId = mlir::FlatSymbolRefAttr();
1376+
}
13001377
}
13011378

13021379
// Explicit map captures are captured ByRef by default,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2578,18 +2578,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25782578
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
25792579
name << sym.name().ToString();
25802580

2581-
mlir::FlatSymbolRefAttr mapperId;
2582-
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
2583-
auto &typeSpec = sym.GetType()->derivedTypeSpec();
2584-
std::string mapperIdName =
2585-
typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName;
2586-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
2587-
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2588-
if (converter.getModuleOp().lookupSymbol(mapperIdName))
2589-
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2590-
mapperIdName);
2591-
}
2592-
25932581
fir::factory::AddrAndBoundsInfo info =
25942582
Fortran::lower::getDataOperandBaseAddr(
25952583
converter, firOpBuilder, sym.GetUltimate(),
@@ -2609,6 +2597,46 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
26092597
mapFlagAndKind = getImplicitMapTypeAndKind(
26102598
firOpBuilder, converter, defaultMaps, eleType, loc, sym);
26112599

2600+
mlir::FlatSymbolRefAttr mapperId;
2601+
if (defaultMaps.empty()) {
2602+
const semantics::DerivedTypeSpec *typeSpec =
2603+
sym.GetType() ? sym.GetType()->AsDerived() : nullptr;
2604+
if (typeSpec) {
2605+
auto getDefaultMapperName = [&]() -> std::string {
2606+
std::string mapperIdName =
2607+
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
2608+
if (auto *mapperSym =
2609+
converter.getCurrentScope().FindSymbol(mapperIdName))
2610+
mapperIdName =
2611+
converter.mangleName(mapperIdName, mapperSym->owner());
2612+
else
2613+
mapperIdName =
2614+
converter.mangleName(mapperIdName, *typeSpec->GetScope());
2615+
return mapperIdName;
2616+
};
2617+
2618+
std::string mapperIdName = getDefaultMapperName();
2619+
if (!mapperIdName.empty()) {
2620+
bool mapperExists =
2621+
converter.getModuleOp().lookupSymbol(mapperIdName);
2622+
bool allowImplicitMapper =
2623+
semantics::IsAllocatableOrObjectPointer(&sym);
2624+
if (mapperExists || allowImplicitMapper) {
2625+
if (!mapperExists) {
2626+
auto recordType = mlir::dyn_cast_or_null<fir::RecordType>(
2627+
converter.genType(*typeSpec));
2628+
if (recordType)
2629+
mapperId = getOrGenImplicitDefaultDeclareMapper(
2630+
converter, loc, recordType, mapperIdName);
2631+
} else {
2632+
mapperId = mlir::FlatSymbolRefAttr::get(
2633+
&converter.getMLIRContext(), mapperIdName);
2634+
}
2635+
}
2636+
}
2637+
}
2638+
}
2639+
26122640
mlir::Value mapOp = createMapInfoOp(
26132641
firOpBuilder, converter.getCurrentLocation(), baseOp,
26142642
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{},

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,21 @@
2020
#include <flang/Lower/OpenMP/Clauses.h>
2121
#include <flang/Lower/PFTBuilder.h>
2222
#include <flang/Lower/Support/PrivateReductionUtils.h>
23+
#include <flang/Optimizer/Builder/BoxValue.h>
2324
#include <flang/Optimizer/Builder/FIRBuilder.h>
2425
#include <flang/Optimizer/Builder/Todo.h>
26+
#include <flang/Optimizer/HLFIR/HLFIROps.h>
2527
#include <flang/Parser/openmp-utils.h>
2628
#include <flang/Parser/parse-tree.h>
2729
#include <flang/Parser/tools.h>
2830
#include <flang/Semantics/tools.h>
2931
#include <flang/Semantics/type.h>
3032
#include <flang/Utils/OpenMP.h>
33+
#include <llvm/ADT/SmallPtrSet.h>
34+
#include <llvm/ADT/StringRef.h>
3135
#include <llvm/Support/CommandLine.h>
3236

37+
#include <functional>
3338
#include <iterator>
3439

3540
template <typename T>
@@ -61,6 +66,139 @@ namespace Fortran {
6166
namespace lower {
6267
namespace omp {
6368

69+
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
70+
lower::AbstractConverter &converter, mlir::Location loc,
71+
fir::RecordType recordType, llvm::StringRef mapperNameStr) {
72+
if (converter.getModuleOp().lookupSymbol(mapperNameStr))
73+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
74+
mapperNameStr);
75+
76+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
77+
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
78+
79+
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
80+
auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
81+
loc, mapperNameStr, recordType);
82+
auto &region = declMapperOp.getRegion();
83+
firOpBuilder.createBlock(&region);
84+
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
85+
86+
auto declareOp =
87+
firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
88+
89+
const auto genBoundsOps = [&](mlir::Value mapVal,
90+
llvm::SmallVectorImpl<mlir::Value> &bounds) {
91+
fir::ExtendedValue extVal =
92+
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
93+
hlfir::Entity{mapVal},
94+
/*contiguousHint=*/true)
95+
.first;
96+
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
97+
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
98+
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
99+
mlir::omp::MapBoundsType>(
100+
firOpBuilder, info, extVal,
101+
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
102+
};
103+
104+
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
105+
mlir::Type fieldTy, mlir::Type recType) {
106+
mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
107+
loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
108+
fir::getTypeParams(rec));
109+
return firOpBuilder.create<fir::CoordinateOp>(
110+
loc, firOpBuilder.getRefType(fieldTy), rec, field);
111+
};
112+
113+
mlir::omp::DeclareMapperInfoOperands clauseOps;
114+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
115+
llvm::SmallVector<mlir::Value> memberMapOps;
116+
117+
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to;
118+
mapFlag |= mlir::omp::ClauseMapFlags::from;
119+
mapFlag |= mlir::omp::ClauseMapFlags::implicit;
120+
mlir::omp::VariableCaptureKind captureKind =
121+
mlir::omp::VariableCaptureKind::ByRef;
122+
123+
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
124+
const auto &memberName = entry.value().first;
125+
const auto &memberType = entry.value().second;
126+
mlir::FlatSymbolRefAttr mapperId;
127+
if (auto recType = mlir::dyn_cast<fir::RecordType>(
128+
fir::getFortranElementType(memberType))) {
129+
std::string mapperIdName =
130+
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
131+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
132+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
133+
else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
134+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
135+
136+
mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
137+
mapperIdName);
138+
}
139+
140+
auto ref =
141+
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
142+
llvm::SmallVector<mlir::Value> bounds;
143+
genBoundsOps(ref, bounds);
144+
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(
145+
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
146+
bounds,
147+
/*members=*/{},
148+
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
149+
/*partialMap=*/false, mapperId);
150+
memberMapOps.emplace_back(mapOp);
151+
memberPlacementIndices.emplace_back(
152+
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
153+
}
154+
155+
llvm::SmallVector<mlir::Value> bounds;
156+
genBoundsOps(declareOp.getOriginalBase(), bounds);
157+
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
158+
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
159+
firOpBuilder, loc, declareOp.getOriginalBase(),
160+
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
161+
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
162+
captureKind, declareOp.getType(0),
163+
/*partialMap=*/true);
164+
165+
clauseOps.mapVars.emplace_back(mapOp);
166+
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
167+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
168+
mapperNameStr);
169+
}
170+
171+
bool requiresImplicitDefaultDeclareMapper(
172+
const semantics::DerivedTypeSpec &typeSpec) {
173+
std::string rawName = typeSpec.name().ToString();
174+
std::string loweredName = llvm::StringRef(rawName).lower();
175+
llvm::StringRef typeNameRef(loweredName);
176+
if (typeNameRef.contains("c_ptr") || typeNameRef.contains("c_funptr"))
177+
return true;
178+
179+
llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited;
180+
181+
std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper =
182+
[&](const semantics::DerivedTypeSpec &spec) -> bool {
183+
if (!visited.insert(&spec).second)
184+
return false;
185+
186+
semantics::DirectComponentIterator directComponents{spec};
187+
for (const semantics::Symbol &component : directComponents) {
188+
if (semantics::IsAllocatableOrPointer(component))
189+
return true;
190+
191+
if (const semantics::DeclTypeSpec *declType = component.GetType())
192+
if (const auto *nested = declType->AsDerived())
193+
if (requiresMapper(*nested))
194+
return true;
195+
}
196+
return false;
197+
};
198+
199+
return requiresMapper(typeSpec);
200+
}
201+
64202
int64_t getCollapseValue(const List<Clause> &clauses) {
65203
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
66204
return clause.id == llvm::omp::Clause::OMPC_collapse;
@@ -537,6 +675,12 @@ void insertChildMapInfoIntoParent(
537675
mapOperands[std::distance(mapSyms.begin(), parentIter)]
538676
.getDefiningOp());
539677

678+
// Once explicit members are attached to a parent map, do not also invoke
679+
// a declare mapper on it, otherwise the mapper would remap the same
680+
// components leading to duplicate mappings at runtime.
681+
if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr())
682+
mapOp.setMapperIdAttr(nullptr);
683+
540684
// NOTE: To maintain appropriate SSA ordering, we move the parent map
541685
// which will now have references to its children after the last
542686
// of its members to be generated. This is necessary when a user

0 commit comments

Comments
 (0)