Skip to content

Commit 833ffa5

Browse files
authored
[Flang][OpenMP] Update declare mapper lookup via use-module (#167903)
1 parent d4c8cfe commit 833ffa5

File tree

13 files changed

+199
-37
lines changed

13 files changed

+199
-37
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ struct OmpClauseList;
4141

4242
namespace semantics {
4343
class Symbol;
44+
class Scope;
4445
class SemanticsContext;
4546
} // namespace semantics
4647

@@ -97,6 +98,13 @@ bool markOpenMPDeferredDeclareTargetFunctions(
9798
AbstractConverter &);
9899
void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);
99100

101+
// Materialize omp.declare_mapper ops for mapper declarations found in
102+
// imported modules. If \p scope is null, materialize for the whole
103+
// semantics global scope; otherwise, operate recursively starting at \p scope.
104+
void materializeOpenMPDeclareMappers(
105+
Fortran::lower::AbstractConverter &, Fortran::semantics::SemanticsContext &,
106+
const Fortran::semantics::Scope *scope = nullptr);
107+
100108
} // namespace lower
101109
} // namespace Fortran
102110

flang/include/flang/Semantics/symbol.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,14 +777,32 @@ class UserReductionDetails {
777777
DeclVector declList_;
778778
};
779779

780+
// Used for OpenMP DECLARE MAPPER, it holds the declaration constructs
781+
// so they can be serialized into module files and later re-parsed when
782+
// USE-associated.
783+
class MapperDetails {
784+
public:
785+
using DeclVector = std::vector<const parser::OpenMPDeclarativeConstruct *>;
786+
787+
MapperDetails() = default;
788+
789+
void AddDecl(const parser::OpenMPDeclarativeConstruct *decl) {
790+
declList_.emplace_back(decl);
791+
}
792+
const DeclVector &GetDeclList() const { return declList_; }
793+
794+
private:
795+
DeclVector declList_;
796+
};
797+
780798
class UnknownDetails {};
781799

782800
using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
783801
SubprogramDetails, SubprogramNameDetails, EntityDetails,
784802
ObjectEntityDetails, ProcEntityDetails, AssocEntityDetails,
785803
DerivedTypeDetails, UseDetails, UseErrorDetails, HostAssocDetails,
786804
GenericDetails, ProcBindingDetails, NamelistDetails, CommonBlockDetails,
787-
TypeParamDetails, MiscDetails, UserReductionDetails>;
805+
TypeParamDetails, MiscDetails, UserReductionDetails, MapperDetails>;
788806
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Details &);
789807
std::string DetailsToString(const Details &);
790808

flang/lib/Lower/Bridge.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,13 @@ class FirConverter : public Fortran::lower::AbstractConverter {
448448
}
449449
});
450450

451+
// Ensure imported OpenMP declare mappers are materialized at module
452+
// scope before lowering any constructs that may reference them.
453+
createBuilderOutsideOfFuncOpAndDo([&]() {
454+
Fortran::lower::materializeOpenMPDeclareMappers(
455+
*this, bridge.getSemanticsContext());
456+
});
457+
451458
// Create definitions of intrinsic module constants.
452459
createBuilderOutsideOfFuncOpAndDo(
453460
[&]() { createIntrinsicModuleDefinitions(pft); });

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,10 +1397,14 @@ bool ClauseProcessor::processMap(
13971397
}
13981398
if (mappers) {
13991399
assert(mappers->size() == 1 && "more than one mapper");
1400-
mapperIdName = mappers->front().v.id().symbol->name().ToString();
1401-
if (mapperIdName != "default")
1402-
mapperIdName = converter.mangleName(
1403-
mapperIdName, mappers->front().v.id().symbol->owner());
1400+
const semantics::Symbol *mapperSym = mappers->front().v.id().symbol;
1401+
mapperIdName = mapperSym->name().ToString();
1402+
if (mapperIdName != "default") {
1403+
// Mangle with the ultimate owner so that use-associated mapper
1404+
// identifiers resolve to the same symbol as their defining scope.
1405+
const semantics::Symbol &ultimate = mapperSym->GetUltimate();
1406+
mapperIdName = converter.mangleName(mapperIdName, ultimate.owner());
1407+
}
14041408
}
14051409

14061410
processMapObjects(stmtCtx, clauseLocation,

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3553,10 +3553,10 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35533553
TODO(converter.getCurrentLocation(), "OpenMPDeclareSimdConstruct");
35543554
}
35553555

3556-
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3557-
semantics::SemanticsContext &semaCtx,
3558-
lower::pft::Evaluation &eval,
3559-
const parser::OpenMPDeclareMapperConstruct &construct) {
3556+
static void genOpenMPDeclareMapperImpl(
3557+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
3558+
const parser::OpenMPDeclareMapperConstruct &construct,
3559+
const semantics::Symbol *mapperSymOpt = nullptr) {
35603560
mlir::Location loc = converter.genLocation(construct.source);
35613561
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
35623562
const parser::OmpArgumentList &args = construct.v.Arguments();
@@ -3572,8 +3572,17 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35723572
"Expected derived type");
35733573

35743574
std::string mapperNameStr = mapperName;
3575-
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperNameStr))
3575+
if (mapperSymOpt && mapperNameStr != "default") {
3576+
mapperNameStr = converter.mangleName(mapperNameStr, mapperSymOpt->owner());
3577+
} else if (auto *sym =
3578+
converter.getCurrentScope().FindSymbol(mapperNameStr)) {
35763579
mapperNameStr = converter.mangleName(mapperNameStr, sym->owner());
3580+
}
3581+
3582+
// If the mapper op already exists (e.g., created by regular lowering or by
3583+
// materialization of imported mappers), do not recreate it.
3584+
if (converter.getModuleOp().lookupSymbol(mapperNameStr))
3585+
return;
35773586

35783587
// Save current insertion point before moving to the module scope to create
35793588
// the DeclareMapperOp
@@ -3596,6 +3605,13 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
35963605
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseOps.mapVars);
35973606
}
35983607

3608+
static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3609+
semantics::SemanticsContext &semaCtx,
3610+
lower::pft::Evaluation &eval,
3611+
const parser::OpenMPDeclareMapperConstruct &construct) {
3612+
genOpenMPDeclareMapperImpl(converter, semaCtx, construct);
3613+
}
3614+
35993615
static void
36003616
genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
36013617
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
@@ -4231,3 +4247,36 @@ void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
42314247
offloadMod.setRequires(mlirFlags);
42324248
}
42334249
}
4250+
4251+
// Walk scopes and materialize omp.declare_mapper ops for mapper declarations
4252+
// found in imported modules. If \p scope is null, start from the global scope.
4253+
void Fortran::lower::materializeOpenMPDeclareMappers(
4254+
Fortran::lower::AbstractConverter &converter,
4255+
semantics::SemanticsContext &semaCtx, const semantics::Scope *scope) {
4256+
const semantics::Scope &root = scope ? *scope : semaCtx.globalScope();
4257+
4258+
// Recurse into child scopes first (modules, submodules, etc.).
4259+
for (const semantics::Scope &child : root.children())
4260+
materializeOpenMPDeclareMappers(converter, semaCtx, &child);
4261+
4262+
// Only consider module scopes to avoid duplicating local constructs.
4263+
if (!root.IsModule())
4264+
return;
4265+
4266+
// Only materialize for modules coming from mod files to avoid duplicates.
4267+
if (!root.symbol() || !root.symbol()->test(semantics::Symbol::Flag::ModFile))
4268+
return;
4269+
4270+
// Scan symbols in this module scope for MapperDetails.
4271+
for (auto &it : root) {
4272+
const semantics::Symbol &sym = *it.second;
4273+
if (auto *md = sym.detailsIf<semantics::MapperDetails>()) {
4274+
for (const auto *decl : md->GetDeclList()) {
4275+
if (const auto *mapperDecl =
4276+
std::get_if<parser::OpenMPDeclareMapperConstruct>(&decl->u)) {
4277+
genOpenMPDeclareMapperImpl(converter, semaCtx, *mapperDecl, &sym);
4278+
}
4279+
}
4280+
}
4281+
}
4282+
}

flang/lib/Semantics/mod-file.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ static void PutBound(llvm::raw_ostream &, const Bound &);
5959
static void PutShapeSpec(llvm::raw_ostream &, const ShapeSpec &);
6060
static void PutShape(
6161
llvm::raw_ostream &, const ArraySpec &, char open, char close);
62+
static void PutMapper(llvm::raw_ostream &, const Symbol &, SemanticsContext &);
6263

6364
static llvm::raw_ostream &PutAttr(llvm::raw_ostream &, Attr);
6465
static llvm::raw_ostream &PutType(llvm::raw_ostream &, const DeclTypeSpec &);
@@ -938,6 +939,7 @@ void ModFileWriter::PutEntity(llvm::raw_ostream &os, const Symbol &symbol) {
938939
[&](const ProcEntityDetails &) { PutProcEntity(os, symbol); },
939940
[&](const TypeParamDetails &) { PutTypeParam(os, symbol); },
940941
[&](const UserReductionDetails &) { PutUserReduction(os, symbol); },
942+
[&](const MapperDetails &) { PutMapper(decls_, symbol, context_); },
941943
[&](const auto &) {
942944
common::die("PutEntity: unexpected details: %s",
943945
DetailsToString(symbol.details()).c_str());
@@ -1101,6 +1103,16 @@ void ModFileWriter::PutUserReduction(
11011103
}
11021104
}
11031105

1106+
static void PutMapper(
1107+
llvm::raw_ostream &os, const Symbol &symbol, SemanticsContext &context) {
1108+
const auto &details{symbol.get<MapperDetails>()};
1109+
// Emit each saved DECLARE MAPPER construct as-is, so that consumers of the
1110+
// module can reparse it and recreate the mapper symbol and semantics state.
1111+
for (const auto *decl : details.GetDeclList()) {
1112+
Unparse(os, *decl, context.langOptions());
1113+
}
1114+
}
1115+
11041116
void PutInit(llvm::raw_ostream &os, const Symbol &symbol, const MaybeExpr &init,
11051117
const parser::Expr *unanalyzed, SemanticsContext &context) {
11061118
if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {

flang/lib/Semantics/resolve-names.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1852,21 +1852,25 @@ bool OmpVisitor::Pre(const parser::OmpMapClause &x) {
18521852
// TODO: Do we need a specific flag or type here, to distinghuish against
18531853
// other ConstructName things? Leaving this for the full implementation
18541854
// of mapper lowering.
1855-
auto *misc{symbol->detailsIf<MiscDetails>()};
1856-
if (!misc || misc->kind() != MiscDetails::Kind::ConstructName)
1855+
auto &ultimate{symbol->GetUltimate()};
1856+
auto *misc{ultimate.detailsIf<MiscDetails>()};
1857+
auto *md{ultimate.detailsIf<MapperDetails>()};
1858+
if (!md && (!misc || misc->kind() != MiscDetails::Kind::ConstructName))
18571859
context().Say(mapper->v.source,
18581860
"Name '%s' should be a mapper name"_err_en_US, mapper->v.source);
18591861
else
18601862
mapper->v.symbol = symbol;
18611863
} else {
1862-
mapper->v.symbol =
1863-
&MakeSymbol(mapper->v, MiscDetails{MiscDetails::Kind::ConstructName});
1864-
// TODO: When completing the implementation, we probably want to error if
1865-
// the symbol is not declared, but right now, testing that the TODO for
1866-
// OmpMapClause happens is obscured by the TODO for declare mapper, so
1867-
// leaving this out. Remove the above line once the declare mapper is
1868-
// implemented. context().Say(mapper->v.source, "'%s' not
1869-
// declared"_err_en_US, mapper->v.source);
1864+
// Allow the special 'default' mapper identifier without prior
1865+
// declaration so lowering can recognize and handle it. Emit an
1866+
// error for any other missing mapper identifier.
1867+
if (mapper->v.source.ToString() == "default") {
1868+
mapper->v.symbol = &MakeSymbol(
1869+
mapper->v, MiscDetails{MiscDetails::Kind::ConstructName});
1870+
} else {
1871+
context().Say(
1872+
mapper->v.source, "'%s' not declared"_err_en_US, mapper->v.source);
1873+
}
18701874
}
18711875
}
18721876
return true;
@@ -1880,8 +1884,16 @@ void OmpVisitor::ProcessMapperSpecifier(const parser::OmpMapperSpecifier &spec,
18801884
// the type has been fully processed.
18811885
BeginDeclTypeSpec();
18821886
auto &mapperName{std::get<std::string>(spec.t)};
1883-
MakeSymbol(parser::CharBlock(mapperName), Attrs{},
1884-
MiscDetails{MiscDetails::Kind::ConstructName});
1887+
// Create or update the mapper symbol with MapperDetails and
1888+
// keep track of the declarative construct for module emission.
1889+
SourceName mapperSource{context().SaveTempName(std::string{mapperName})};
1890+
Symbol &mapperSym{MakeSymbol(mapperSource, Attrs{})};
1891+
if (!mapperSym.detailsIf<MapperDetails>()) {
1892+
mapperSym.set_details(MapperDetails{});
1893+
}
1894+
if (!context().langOptions().OpenMPSimd) {
1895+
mapperSym.get<MapperDetails>().AddDecl(declaratives_.back());
1896+
}
18851897
PushScope(Scope::Kind::OtherConstruct, nullptr);
18861898
Walk(std::get<parser::TypeSpec>(spec.t));
18871899
auto &varName{std::get<parser::Name>(spec.t)};
@@ -3611,10 +3623,20 @@ void ModuleVisitor::Post(const parser::UseStmt &x) {
36113623
rename.u);
36123624
}
36133625
for (const auto &[name, symbol] : *useModuleScope_) {
3626+
// Default USE imports public names, excluding intrinsic-only and most
3627+
// miscellaneous details. Allow OpenMP mapper identifiers represented
3628+
// as MapperDetails, and also legacy MiscDetails::ConstructName.
3629+
bool isMapper{symbol->has<MapperDetails>()};
3630+
if (!isMapper) {
3631+
if (const auto *misc{symbol->detailsIf<MiscDetails>()}) {
3632+
isMapper = misc->kind() == MiscDetails::Kind::ConstructName;
3633+
}
3634+
}
36143635
if (symbol->attrs().test(Attr::PUBLIC) && !IsUseRenamed(symbol->name()) &&
36153636
(!symbol->implicitAttrs().test(Attr::INTRINSIC) ||
36163637
symbol->has<UseDetails>()) &&
3617-
!symbol->has<MiscDetails>() && useNames.count(name) == 0) {
3638+
(!symbol->has<MiscDetails>() || isMapper) &&
3639+
useNames.count(name) == 0) {
36183640
SourceName location{x.moduleName.source};
36193641
if (auto *localSymbol{FindInScope(name)}) {
36203642
DoAddUse(location, localSymbol->name(), *localSymbol, *symbol);

flang/lib/Semantics/symbol.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,8 @@ std::string DetailsToString(const Details &details) {
338338
[](const TypeParamDetails &) { return "TypeParam"; },
339339
[](const MiscDetails &) { return "Misc"; },
340340
[](const AssocEntityDetails &) { return "AssocEntity"; },
341-
[](const UserReductionDetails &) { return "UserReductionDetails"; }},
341+
[](const UserReductionDetails &) { return "UserReductionDetails"; },
342+
[](const MapperDetails &) { return "MapperDetails"; }},
342343
details);
343344
}
344345

@@ -379,6 +380,7 @@ bool Symbol::CanReplaceDetails(const Details &details) const {
379380
[&](const UserReductionDetails &) {
380381
return has<UserReductionDetails>();
381382
},
383+
[&](const MapperDetails &) { return has<MapperDetails>(); },
382384
[](const auto &) { return false; },
383385
},
384386
details);
@@ -685,6 +687,8 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) {
685687
DumpType(os, type);
686688
}
687689
},
690+
// Avoid recursive streaming for MapperDetails; nothing more to dump
691+
[&](const MapperDetails &) {},
688692
[&](const auto &x) { os << x; },
689693
},
690694
details);

flang/test/Lower/OpenMP/declare-mapper.f90

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
77
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
88
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
9-
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
9+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
10+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -module-dir %t %t/omp-declare-mapper-7.mod.f90 -o - >/dev/null
11+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -J %t %t/omp-declare-mapper-7.use.f90 -o - | FileCheck %t/omp-declare-mapper-7.use.f90
1012

1113
!--- omp-declare-mapper-1.f90
1214
subroutine declare_mapper_1
@@ -301,3 +303,25 @@ subroutine declare_mapper_nested_parent
301303
r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
302304
!$omp end target
303305
end subroutine declare_mapper_nested_parent
306+
307+
!--- omp-declare-mapper-7.mod.f90
308+
! Module with DECLARE MAPPER to be compiled separately
309+
module m_mod
310+
implicit none
311+
type :: mty
312+
integer :: x
313+
end type mty
314+
!$omp declare mapper(mymap : mty :: v) map(tofrom: v%x)
315+
end module m_mod
316+
317+
!--- omp-declare-mapper-7.use.f90
318+
! Consumer program that USEs the module and applies the mapper by name.
319+
! CHECK: %{{.*}} = omp.map.info {{.*}} mapper(@{{.*mymap}}) {{.*}} {name = "a"}
320+
program use_module_mapper
321+
use m_mod
322+
implicit none
323+
type(mty) :: a
324+
!$omp target map(mapper(mymap) : a)
325+
a%x = 42
326+
!$omp end target
327+
end program use_module_mapper

flang/test/Parser/OpenMP/map-modifiers.f90

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ subroutine f21(x, y)
320320
integer :: x(10)
321321
integer :: y
322322
integer, parameter :: p = 23
323-
!$omp target map(mapper(xx), from: x)
323+
!$omp target map(mapper(default), from: x)
324324
x = x + 1
325325
!$omp end target
326326
end
@@ -329,15 +329,15 @@ subroutine f21(x, y)
329329
!UNPARSE: INTEGER x(10_4)
330330
!UNPARSE: INTEGER y
331331
!UNPARSE: INTEGER, PARAMETER :: p = 23_4
332-
!UNPARSE: !$OMP TARGET MAP(MAPPER(XX), FROM: X)
332+
!UNPARSE: !$OMP TARGET MAP(MAPPER(DEFAULT), FROM: X)
333333
!UNPARSE: x=x+1_4
334334
!UNPARSE: !$OMP END TARGET
335335
!UNPARSE: END SUBROUTINE
336336

337337
!PARSE-TREE: OmpBeginDirective
338338
!PARSE-TREE: | OmpDirectiveName -> llvm::omp::Directive = target
339339
!PARSE-TREE: | OmpClauseList -> OmpClause -> Map -> OmpMapClause
340-
!PARSE-TREE: | | Modifier -> OmpMapper -> Name = 'xx'
340+
!PARSE-TREE: | | Modifier -> OmpMapper -> Name = 'default'
341341
!PARSE-TREE: | | Modifier -> OmpMapType -> Value = From
342342
!PARSE-TREE: | | OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'x'
343343

@@ -375,4 +375,3 @@ subroutine f22(x)
375375
!PARSE-TREE: | | SectionSubscript -> Integer -> Expr = 'i'
376376
!PARSE-TREE: | | | Designator -> DataRef -> Name = 'i'
377377
!PARSE-TREE: | bool = 'true'
378-

0 commit comments

Comments
 (0)