Skip to content

Commit 8aa7d82

Browse files
authored
[OpenMP][Flang] Emit default declare mappers implicitly for derived types (#140562)
This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types.
1 parent 282bdb4 commit 8aa7d82

File tree

9 files changed

+377
-50
lines changed

9 files changed

+377
-50
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "flang/Lower/OpenMP/Clauses.h"
1818
#include "flang/Lower/PFTBuilder.h"
1919
#include "flang/Lower/Support/ReductionProcessor.h"
20+
#include "flang/Optimizer/Dialect/FIRType.h"
2021
#include "flang/Parser/tools.h"
2122
#include "flang/Semantics/tools.h"
2223
#include "flang/Utils/OpenMP.h"
@@ -1228,26 +1229,66 @@ void ClauseProcessor::processMapObjects(
12281229
llvm::StringRef mapperIdNameRef) const {
12291230
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
12301231

1231-
auto getDefaultMapperID = [&](const omp::Object &object,
1232-
std::string &mapperIdName) {
1233-
if (!mlir::isa<mlir::omp::DeclareMapperOp>(
1234-
firOpBuilder.getRegion().getParentOp())) {
1235-
const semantics::DerivedTypeSpec *typeSpec = nullptr;
1232+
auto getSymbolDerivedType = [](const semantics::Symbol &symbol)
1233+
-> const semantics::DerivedTypeSpec * {
1234+
const semantics::Symbol &ultimate = symbol.GetUltimate();
1235+
if (const semantics::DeclTypeSpec *declType = ultimate.GetType())
1236+
if (const auto *derived = declType->AsDerived())
1237+
return derived;
1238+
return nullptr;
1239+
};
12361240

1237-
if (object.sym()->owner().IsDerivedType())
1238-
typeSpec = object.sym()->owner().derivedTypeSpec();
1239-
else if (object.sym()->GetType() &&
1240-
object.sym()->GetType()->category() ==
1241-
semantics::DeclTypeSpec::TypeDerived)
1242-
typeSpec = &object.sym()->GetType()->derivedTypeSpec();
1243-
1244-
if (typeSpec) {
1245-
mapperIdName =
1246-
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
1247-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
1248-
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1249-
}
1241+
auto addImplicitMapper = [&](const omp::Object &object,
1242+
std::string &mapperIdName,
1243+
bool allowGenerate) -> mlir::FlatSymbolRefAttr {
1244+
if (mapperIdName.empty())
1245+
return mlir::FlatSymbolRefAttr();
1246+
1247+
if (converter.getModuleOp().lookupSymbol(mapperIdName))
1248+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
1249+
mapperIdName);
1250+
1251+
if (!allowGenerate)
1252+
return mlir::FlatSymbolRefAttr();
1253+
1254+
const semantics::DerivedTypeSpec *typeSpec =
1255+
getSymbolDerivedType(*object.sym());
1256+
if (!typeSpec && object.sym()->owner().IsDerivedType())
1257+
typeSpec = object.sym()->owner().derivedTypeSpec();
1258+
1259+
if (!typeSpec)
1260+
return mlir::FlatSymbolRefAttr();
1261+
1262+
mlir::Type type = converter.genType(*typeSpec);
1263+
auto recordType = mlir::dyn_cast<fir::RecordType>(type);
1264+
if (!recordType)
1265+
return mlir::FlatSymbolRefAttr();
1266+
1267+
return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
1268+
recordType, mapperIdName);
1269+
};
1270+
1271+
auto getDefaultMapperID =
1272+
[&](const semantics::DerivedTypeSpec *typeSpec) -> std::string {
1273+
if (mlir::isa<mlir::omp::DeclareMapperOp>(
1274+
firOpBuilder.getRegion().getParentOp()) ||
1275+
!typeSpec)
1276+
return {};
1277+
1278+
std::string mapperIdName =
1279+
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
1280+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName)) {
1281+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1282+
} else {
1283+
mapperIdName = converter.mangleName(mapperIdName, *typeSpec->GetScope());
12501284
}
1285+
1286+
// Make sure we don't return a mapper to self.
1287+
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
1288+
firOpBuilder.getRegion().getParentOp()))
1289+
if (mapperIdName == declMapOp.getSymName())
1290+
return {};
1291+
return mapperIdName;
12511292
};
12521293

12531294
// Create the mapper symbol from its name, if specified.
@@ -1256,8 +1297,13 @@ void ClauseProcessor::processMapObjects(
12561297
mapperIdNameRef != "__implicit_mapper") {
12571298
std::string mapperIdName = mapperIdNameRef.str();
12581299
const omp::Object &object = objects.front();
1259-
if (mapperIdNameRef == "default")
1260-
getDefaultMapperID(object, mapperIdName);
1300+
if (mapperIdNameRef == "default") {
1301+
const semantics::DerivedTypeSpec *typeSpec =
1302+
getSymbolDerivedType(*object.sym());
1303+
if (!typeSpec && object.sym()->owner().IsDerivedType())
1304+
typeSpec = object.sym()->owner().derivedTypeSpec();
1305+
mapperIdName = getDefaultMapperID(typeSpec);
1306+
}
12611307
assert(converter.getModuleOp().lookupSymbol(mapperIdName) &&
12621308
"mapper not found");
12631309
mapperId =
@@ -1295,13 +1341,25 @@ void ClauseProcessor::processMapObjects(
12951341
}
12961342
}
12971343

1344+
const semantics::DerivedTypeSpec *objectTypeSpec =
1345+
getSymbolDerivedType(*object.sym());
1346+
12981347
if (mapperIdNameRef == "__implicit_mapper") {
1299-
std::string mapperIdName;
1300-
getDefaultMapperID(object, mapperIdName);
1301-
mapperId = converter.getModuleOp().lookupSymbol(mapperIdName)
1302-
? mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
1303-
mapperIdName)
1304-
: mlir::FlatSymbolRefAttr();
1348+
if (parentObj.has_value()) {
1349+
mapperId = mlir::FlatSymbolRefAttr();
1350+
} else if (objectTypeSpec) {
1351+
std::string mapperIdName = getDefaultMapperID(objectTypeSpec);
1352+
bool needsDefaultMapper =
1353+
semantics::IsAllocatableOrObjectPointer(object.sym()) ||
1354+
requiresImplicitDefaultDeclareMapper(*objectTypeSpec);
1355+
if (!mapperIdName.empty())
1356+
mapperId = addImplicitMapper(object, mapperIdName,
1357+
/*allowGenerate=*/needsDefaultMapper);
1358+
else
1359+
mapperId = mlir::FlatSymbolRefAttr();
1360+
} else {
1361+
mapperId = mlir::FlatSymbolRefAttr();
1362+
}
13051363
}
13061364

13071365
// Explicit map captures are captured ByRef by default,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,18 +2581,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25812581
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
25822582
name << sym.name().ToString();
25832583

2584-
mlir::FlatSymbolRefAttr mapperId;
2585-
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
2586-
auto &typeSpec = sym.GetType()->derivedTypeSpec();
2587-
std::string mapperIdName =
2588-
typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName;
2589-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
2590-
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2591-
if (converter.getModuleOp().lookupSymbol(mapperIdName))
2592-
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2593-
mapperIdName);
2594-
}
2595-
25962584
fir::factory::AddrAndBoundsInfo info =
25972585
Fortran::lower::getDataOperandBaseAddr(
25982586
converter, firOpBuilder, sym.GetUltimate(),
@@ -2612,6 +2600,44 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
26122600
mapFlagAndKind = getImplicitMapTypeAndKind(
26132601
firOpBuilder, converter, defaultMaps, eleType, loc, sym);
26142602

2603+
mlir::FlatSymbolRefAttr mapperId;
2604+
if (defaultMaps.empty()) {
2605+
// TODO: Honor user-provided defaultmap clauses (aggregates/pointers)
2606+
// instead of blanket-disabling implicit mapper generation whenever any
2607+
// explicit default map is present.
2608+
const semantics::DerivedTypeSpec *typeSpec =
2609+
sym.GetType() ? sym.GetType()->AsDerived() : nullptr;
2610+
if (typeSpec) {
2611+
std::string mapperIdName =
2612+
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
2613+
if (auto *mapperSym =
2614+
converter.getCurrentScope().FindSymbol(mapperIdName))
2615+
mapperIdName =
2616+
converter.mangleName(mapperIdName, mapperSym->owner());
2617+
else
2618+
mapperIdName =
2619+
converter.mangleName(mapperIdName, *typeSpec->GetScope());
2620+
2621+
if (!mapperIdName.empty()) {
2622+
bool allowImplicitMapper =
2623+
semantics::IsAllocatableOrObjectPointer(&sym);
2624+
bool hasDefaultMapper =
2625+
converter.getModuleOp().lookupSymbol(mapperIdName);
2626+
if (hasDefaultMapper || allowImplicitMapper) {
2627+
if (!hasDefaultMapper) {
2628+
if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>(
2629+
converter.genType(*typeSpec)))
2630+
mapperId = getOrGenImplicitDefaultDeclareMapper(
2631+
converter, loc, recordType, mapperIdName);
2632+
} else {
2633+
mapperId = mlir::FlatSymbolRefAttr::get(
2634+
&converter.getMLIRContext(), mapperIdName);
2635+
}
2636+
}
2637+
}
2638+
}
2639+
}
2640+
26152641
mlir::Value mapOp = createMapInfoOp(
26162642
firOpBuilder, converter.getCurrentLocation(), baseOp,
26172643
/*varPtrPtr=*/mlir::Value{}, name.str(), bounds, /*members=*/{},

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,28 @@
1414

1515
#include "ClauseFinder.h"
1616
#include "flang/Evaluate/fold.h"
17+
#include "flang/Evaluate/tools.h"
1718
#include <flang/Lower/AbstractConverter.h>
1819
#include <flang/Lower/ConvertType.h>
1920
#include <flang/Lower/DirectivesCommon.h>
2021
#include <flang/Lower/OpenMP/Clauses.h>
2122
#include <flang/Lower/PFTBuilder.h>
2223
#include <flang/Lower/Support/PrivateReductionUtils.h>
24+
#include <flang/Optimizer/Builder/BoxValue.h>
2325
#include <flang/Optimizer/Builder/FIRBuilder.h>
2426
#include <flang/Optimizer/Builder/Todo.h>
27+
#include <flang/Optimizer/HLFIR/HLFIROps.h>
2528
#include <flang/Parser/openmp-utils.h>
2629
#include <flang/Parser/parse-tree.h>
2730
#include <flang/Parser/tools.h>
2831
#include <flang/Semantics/tools.h>
2932
#include <flang/Semantics/type.h>
3033
#include <flang/Utils/OpenMP.h>
34+
#include <llvm/ADT/SmallPtrSet.h>
35+
#include <llvm/ADT/StringRef.h>
3136
#include <llvm/Support/CommandLine.h>
3237

38+
#include <functional>
3339
#include <iterator>
3440

3541
template <typename T>
@@ -61,6 +67,142 @@ namespace Fortran {
6167
namespace lower {
6268
namespace omp {
6369

70+
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
71+
lower::AbstractConverter &converter, mlir::Location loc,
72+
fir::RecordType recordType, llvm::StringRef mapperNameStr) {
73+
if (mapperNameStr.empty())
74+
return {};
75+
76+
if (converter.getModuleOp().lookupSymbol(mapperNameStr))
77+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
78+
mapperNameStr);
79+
80+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
81+
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
82+
83+
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
84+
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
85+
firOpBuilder, loc, mapperNameStr, recordType);
86+
auto &region = declMapperOp.getRegion();
87+
firOpBuilder.createBlock(&region);
88+
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
89+
90+
auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
91+
/*uniq_name=*/"");
92+
93+
const auto genBoundsOps = [&](mlir::Value mapVal,
94+
llvm::SmallVectorImpl<mlir::Value> &bounds) {
95+
fir::ExtendedValue extVal =
96+
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
97+
hlfir::Entity{mapVal},
98+
/*contiguousHint=*/true)
99+
.first;
100+
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
101+
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
102+
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
103+
mlir::omp::MapBoundsType>(
104+
firOpBuilder, info, extVal,
105+
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
106+
};
107+
108+
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
109+
mlir::Type fieldTy, mlir::Type recType) {
110+
mlir::Value field = fir::FieldIndexOp::create(
111+
firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName,
112+
recType, fir::getTypeParams(rec));
113+
return fir::CoordinateOp::create(
114+
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
115+
};
116+
117+
llvm::SmallVector<mlir::Value> clauseMapVars;
118+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
119+
llvm::SmallVector<mlir::Value> memberMapOps;
120+
121+
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
122+
mlir::omp::ClauseMapFlags::from |
123+
mlir::omp::ClauseMapFlags::implicit;
124+
mlir::omp::VariableCaptureKind captureKind =
125+
mlir::omp::VariableCaptureKind::ByRef;
126+
127+
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
128+
const auto &memberName = entry.value().first;
129+
const auto &memberType = entry.value().second;
130+
mlir::FlatSymbolRefAttr mapperId;
131+
if (auto recType = mlir::dyn_cast<fir::RecordType>(
132+
fir::getFortranElementType(memberType))) {
133+
std::string mapperIdName =
134+
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
135+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
136+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
137+
else if (auto *memberSym =
138+
converter.getCurrentScope().FindSymbol(memberName))
139+
mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());
140+
141+
mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
142+
mapperIdName);
143+
}
144+
145+
auto ref =
146+
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
147+
llvm::SmallVector<mlir::Value> bounds;
148+
genBoundsOps(ref, bounds);
149+
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(
150+
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
151+
bounds,
152+
/*members=*/{},
153+
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
154+
/*partialMap=*/false, mapperId);
155+
memberMapOps.emplace_back(mapOp);
156+
memberPlacementIndices.emplace_back(
157+
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
158+
}
159+
160+
llvm::SmallVector<mlir::Value> bounds;
161+
genBoundsOps(declareOp.getOriginalBase(), bounds);
162+
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
163+
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
164+
firOpBuilder, loc, declareOp.getOriginalBase(),
165+
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
166+
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
167+
captureKind, declareOp.getType(0),
168+
/*partialMap=*/true);
169+
170+
clauseMapVars.emplace_back(mapOp);
171+
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
172+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
173+
mapperNameStr);
174+
}
175+
176+
bool requiresImplicitDefaultDeclareMapper(
177+
const semantics::DerivedTypeSpec &typeSpec) {
178+
// ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit
179+
// default mappers available so that OpenMP offloading can correctly map them.
180+
if (semantics::IsIsoCType(&typeSpec))
181+
return true;
182+
183+
llvm::SmallPtrSet<const semantics::DerivedTypeSpec *, 8> visited;
184+
185+
std::function<bool(const semantics::DerivedTypeSpec &)> requiresMapper =
186+
[&](const semantics::DerivedTypeSpec &spec) -> bool {
187+
if (!visited.insert(&spec).second)
188+
return false;
189+
190+
semantics::DirectComponentIterator directComponents{spec};
191+
for (const semantics::Symbol &component : directComponents) {
192+
if (semantics::IsAllocatableOrPointer(component))
193+
return true;
194+
195+
if (const semantics::DeclTypeSpec *declType = component.GetType())
196+
if (const auto *nested = declType->AsDerived())
197+
if (requiresMapper(*nested))
198+
return true;
199+
}
200+
return false;
201+
};
202+
203+
return requiresMapper(typeSpec);
204+
}
205+
64206
int64_t getCollapseValue(const List<Clause> &clauses) {
65207
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
66208
return clause.id == llvm::omp::Clause::OMPC_collapse;
@@ -537,6 +679,12 @@ void insertChildMapInfoIntoParent(
537679
mapOperands[std::distance(mapSyms.begin(), parentIter)]
538680
.getDefiningOp());
539681

682+
// Once explicit members are attached to a parent map, do not also invoke
683+
// a declare mapper on it, otherwise the mapper would remap the same
684+
// components leading to duplicate mappings at runtime.
685+
if (!indices.second.memberMap.empty() && mapOp.getMapperIdAttr())
686+
mapOp.setMapperIdAttr(nullptr);
687+
540688
// NOTE: To maintain appropriate SSA ordering, we move the parent map
541689
// which will now have references to its children after the last
542690
// of its members to be generated. This is necessary when a user

0 commit comments

Comments
 (0)