Skip to content

Commit 9dded5e

Browse files
committed
[OpenMP][Flang] Emit default declare mappers implicitly for derived types
This patch adds support to emit default declare mappers for implicit mapping of derived types when not supplied by user. This especially helps tackle mapping of allocatables of derived types. This supports nested derived types as well.
1 parent 1381ad4 commit 9dded5e

File tree

4 files changed

+160
-4
lines changed

4 files changed

+160
-4
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,8 +1216,19 @@ void ClauseProcessor::processMapObjects(
12161216
typeSpec->name().ToString() + llvm::omp::OmpDefaultMapperName;
12171217
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
12181218
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
1219+
else
1220+
mapperIdName =
1221+
converter.mangleName(mapperIdName, *typeSpec->GetScope());
12191222
}
12201223
}
1224+
1225+
// Make sure we don't return a mapper to self.
1226+
llvm::StringRef parentOpName;
1227+
if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
1228+
firOpBuilder.getRegion().getParentOp()))
1229+
parentOpName = declMapOp.getSymName();
1230+
if (mapperIdName == parentOpName)
1231+
mapperIdName = "";
12211232
};
12221233

12231234
// Create the mapper symbol from its name, if specified.
@@ -1322,7 +1333,7 @@ bool ClauseProcessor::processMap(
13221333
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
13231334
std::string mapperIdName = "__implicit_mapper";
13241335
// If the map type is specified, then process it else set the appropriate
1325-
// default value
1336+
// default value.
13261337
Map::MapType type;
13271338
if (directive == llvm::omp::Directive::OMPD_target_enter_data &&
13281339
semaCtx.langOptions().OpenMPVersion >= 52)

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2470,6 +2470,121 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24702470
queue, item, clauseOps);
24712471
}
24722472

2473+
static mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
2474+
lower::AbstractConverter &converter, mlir::Location loc,
2475+
fir::RecordType recordType, llvm::StringRef mapperNameStr) {
2476+
if (converter.getModuleOp().lookupSymbol(mapperNameStr))
2477+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2478+
mapperNameStr);
2479+
2480+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2481+
2482+
// Save current insertion point before moving to the module scope to create
2483+
// the DeclareMapperOp.
2484+
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
2485+
2486+
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
2487+
auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
2488+
loc, mapperNameStr, recordType);
2489+
auto &region = declMapperOp.getRegion();
2490+
firOpBuilder.createBlock(&region);
2491+
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
2492+
2493+
auto declareOp =
2494+
firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
2495+
2496+
const auto genBoundsOps = [&](mlir::Value mapVal,
2497+
llvm::SmallVectorImpl<mlir::Value> &bounds) {
2498+
fir::ExtendedValue extVal =
2499+
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
2500+
hlfir::Entity{mapVal},
2501+
/*contiguousHint=*/true)
2502+
.first;
2503+
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
2504+
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
2505+
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
2506+
mlir::omp::MapBoundsType>(
2507+
firOpBuilder, info, extVal,
2508+
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
2509+
};
2510+
2511+
// Return a reference to the contents of a derived type with one field.
2512+
// Also return the field type.
2513+
const auto getFieldRef =
2514+
[&](mlir::Value rec, llvm::StringRef fieldName, mlir::Type fieldTy,
2515+
mlir::Type recType) -> std::tuple<mlir::Value, mlir::Type> {
2516+
mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
2517+
loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
2518+
fir::getTypeParams(rec));
2519+
return {firOpBuilder.create<fir::CoordinateOp>(
2520+
loc, firOpBuilder.getRefType(fieldTy), rec, field),
2521+
fieldTy};
2522+
};
2523+
2524+
mlir::omp::DeclareMapperInfoOperands clauseOps;
2525+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
2526+
llvm::SmallVector<mlir::Value> memberMapOps;
2527+
2528+
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2529+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
2530+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
2531+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
2532+
mlir::omp::VariableCaptureKind captureKind =
2533+
mlir::omp::VariableCaptureKind::ByRef;
2534+
2535+
// Populate the declareMapper region with the map information.
2536+
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
2537+
const auto &memberName = entry.value().first;
2538+
const auto &memberType = entry.value().second;
2539+
auto [ref, type] =
2540+
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
2541+
mlir::FlatSymbolRefAttr mapperId;
2542+
if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
2543+
std::string mapperIdName =
2544+
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
2545+
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
2546+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2547+
else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
2548+
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2549+
2550+
mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
2551+
mapperIdName);
2552+
}
2553+
2554+
llvm::SmallVector<mlir::Value> bounds;
2555+
genBoundsOps(ref, bounds);
2556+
mlir::Value mapOp = createMapInfoOp(
2557+
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
2558+
bounds,
2559+
/*members=*/{},
2560+
/*membersIndex=*/mlir::ArrayAttr{},
2561+
static_cast<
2562+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2563+
mapFlag),
2564+
captureKind, ref.getType(), /*partialMap=*/false, mapperId);
2565+
memberMapOps.emplace_back(mapOp);
2566+
memberPlacementIndices.emplace_back(
2567+
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
2568+
}
2569+
2570+
llvm::SmallVector<mlir::Value> bounds;
2571+
genBoundsOps(declareOp.getOriginalBase(), bounds);
2572+
mlir::omp::MapInfoOp mapOp = createMapInfoOp(
2573+
firOpBuilder, loc, declareOp.getOriginalBase(),
2574+
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
2575+
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices),
2576+
static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2577+
mapFlag),
2578+
captureKind, declareOp.getType(0),
2579+
/*partialMap=*/true);
2580+
2581+
clauseOps.mapVars.emplace_back(mapOp);
2582+
2583+
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
2584+
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
2585+
mapperNameStr);
2586+
}
2587+
24732588
static mlir::omp::TargetOp
24742589
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24752590
lower::StatementContext &stmtCtx,
@@ -2546,15 +2661,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25462661
name << sym.name().ToString();
25472662

25482663
mlir::FlatSymbolRefAttr mapperId;
2549-
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
2664+
if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived &&
2665+
defaultMaps.empty()) {
25502666
auto &typeSpec = sym.GetType()->derivedTypeSpec();
25512667
std::string mapperIdName =
25522668
typeSpec.name().ToString() + llvm::omp::OmpDefaultMapperName;
25532669
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
25542670
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
2671+
else
2672+
mapperIdName =
2673+
converter.mangleName(mapperIdName, *typeSpec.GetScope());
2674+
25552675
if (converter.getModuleOp().lookupSymbol(mapperIdName))
25562676
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
25572677
mapperIdName);
2678+
mapperId = getOrGenImplicitDefaultDeclareMapper(
2679+
converter, loc,
2680+
mlir::cast<fir::RecordType>(
2681+
converter.genType(sym.GetType()->derivedTypeSpec())),
2682+
mapperIdName);
25582683
}
25592684

25602685
fir::factory::AddrAndBoundsInfo info =

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ class MapInfoFinalizationPass
441441
getDescriptorMapType(mapType, target)),
442442
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
443443
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
444-
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
444+
op.getMapperIdAttr(), op.getNameAttr(),
445445
/*partial_map=*/builder.getBoolAttr(false));
446446
op.replaceAllUsesWith(newDescParentMapOp.getResult());
447447
op->erase();

flang/test/Lower/OpenMP/derived-type-map.f90

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
22

3+
!CHECK: omp.declare_mapper @[[MAPPER1:_QQFmaptype_derived_implicit_allocatablescalar_and_array.omp.default.mapper]] : !fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {
4+
!CHECK: omp.declare_mapper @[[MAPPER2:_QQFmaptype_derived_implicitscalar_and_array.omp.default.mapper]] : !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {
35

46
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicitEscalar_arr"}
57
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_implicitEscalar_arr"} : (!fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) -> (!fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>)
6-
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}
8+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(implicit, tofrom) capture(ByRef) mapper(@[[MAPPER2]]) -> !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}
79
!CHECK: omp.target map_entries(%[[MAP]] -> %[[ARG0:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_implicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) {
810
subroutine mapType_derived_implicit
911
type :: scalar_and_array
@@ -18,6 +20,24 @@ subroutine mapType_derived_implicit
1820
!$omp end target
1921
end subroutine mapType_derived_implicit
2022

23+
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"}
24+
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFmaptype_derived_implicit_allocatableEscalar_arr"} : {{.*}}
25+
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : {{.*}}) map_clauses(implicit, to) capture(ByRef) mapper(@[[MAPPER1]])
26+
!CHECK: omp.target map_entries(%[[MAP]] -> %[[ARG0:.*]] : !fir.ref<!fir.box<!fir.heap<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>>>, !fir.llvm_ptr<!fir.ref<!fir.type<_QFmaptype_derived_implicit_allocatableTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>>) {
27+
subroutine mapType_derived_implicit_allocatable
28+
type :: scalar_and_array
29+
real(4) :: real
30+
integer(4) :: array(10)
31+
integer(4) :: int
32+
end type scalar_and_array
33+
type(scalar_and_array), allocatable :: scalar_arr
34+
35+
allocate (scalar_arr)
36+
!$omp target
37+
scalar_arr%int = 1
38+
!$omp end target
39+
end subroutine mapType_derived_implicit_allocatable
40+
2141
!CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}> {bindc_name = "scalar_arr", uniq_name = "_QFmaptype_derived_explicitEscalar_arr"}
2242
!CHECK: %[[DECLARE:.*]]:2 = hlfir.declare %[[ALLOCA]] {uniq_name = "_QFmaptype_derived_explicitEscalar_arr"} : (!fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>) -> (!fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>)
2343
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr(%[[DECLARE]]#1 : !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>>, !fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>) map_clauses(tofrom) capture(ByRef) -> !fir.ref<!fir.type<_QFmaptype_derived_explicitTscalar_and_array{real:f32,array:!fir.array<10xi32>,int:i32}>> {name = "scalar_arr"}

0 commit comments

Comments
 (0)