Skip to content

Commit ed12dc5

Browse files
authored
[Flang][OpenMP] Implicitly map nested allocatable components in derived types (#160766)
This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass. - Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper. - Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex. This fixes #156461.
1 parent afb2628 commit ed12dc5

File tree

3 files changed

+302
-74
lines changed

3 files changed

+302
-74
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 220 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141
#include "mlir/Pass/Pass.h"
4242
#include "mlir/Support/LLVM.h"
4343
#include "llvm/ADT/SmallPtrSet.h"
44+
#include "llvm/ADT/StringSet.h"
4445
#include "llvm/Frontend/OpenMP/OMPConstants.h"
46+
#include "llvm/Support/raw_ostream.h"
4547
#include <algorithm>
4648
#include <cstddef>
4749
#include <iterator>
@@ -75,6 +77,112 @@ class MapInfoFinalizationPass
7577
/// | |
7678
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
7779

80+
/// Return true if the given path exists in a list of paths.
81+
static bool
82+
containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
83+
llvm::ArrayRef<int64_t> path) {
84+
return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) {
85+
return p.size() == path.size() &&
86+
std::equal(p.begin(), p.end(), path.begin());
87+
});
88+
}
89+
90+
/// Return true if the given path is already present in
91+
/// op.getMembersIndexAttr().
92+
static bool mappedIndexPathExists(mlir::omp::MapInfoOp op,
93+
llvm::ArrayRef<int64_t> indexPath) {
94+
if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) {
95+
for (mlir::Attribute list : attr) {
96+
auto listAttr = mlir::cast<mlir::ArrayAttr>(list);
97+
if (listAttr.size() != indexPath.size())
98+
continue;
99+
bool allEq = true;
100+
for (auto [i, val] : llvm::enumerate(listAttr)) {
101+
if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) {
102+
allEq = false;
103+
break;
104+
}
105+
}
106+
if (allEq)
107+
return true;
108+
}
109+
}
110+
return false;
111+
}
112+
113+
/// Build a compact string key for an index path for set-based
114+
/// deduplication. Format: "N:v0,v1,..." where N is the length.
115+
static void buildPathKey(llvm::ArrayRef<int64_t> path,
116+
llvm::SmallString<64> &outKey) {
117+
outKey.clear();
118+
llvm::raw_svector_ostream os(outKey);
119+
os << path.size() << ':';
120+
for (size_t i = 0; i < path.size(); ++i) {
121+
if (i)
122+
os << ',';
123+
os << path[i];
124+
}
125+
}
126+
127+
/// Create the member map for coordRef and append it (and its index
128+
/// path) to the provided new* vectors, if it is not already present.
129+
void appendMemberMapIfNew(
130+
mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc,
131+
mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath,
132+
llvm::StringRef memberName,
133+
llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields,
134+
llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) {
135+
// Local de-dup within this op invocation.
136+
if (containsPath(newMemberIndexPaths, indexPath))
137+
return;
138+
// Global de-dup against already present member indices.
139+
if (mappedIndexPathExists(op, indexPath))
140+
return;
141+
142+
if (op.getMapperId()) {
143+
mlir::omp::DeclareMapperOp symbol =
144+
mlir::SymbolTable::lookupNearestSymbolFrom<
145+
mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr());
146+
assert(symbol && "missing symbol for declare mapper identifier");
147+
mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo();
148+
// TODO: Probably a way to cache these keys in someway so we don't
149+
// constantly go through the process of rebuilding them on every check, to
150+
// save some cycles, but it can wait for a subsequent patch.
151+
for (auto v : mapperInfo.getMapVars()) {
152+
mlir::omp::MapInfoOp map =
153+
mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp());
154+
if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath))
155+
return;
156+
}
157+
}
158+
159+
builder.setInsertionPoint(op);
160+
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
161+
builder, coordRef, /*isOptional=*/false, loc);
162+
llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps<
163+
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
164+
builder, info,
165+
hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef})
166+
.first,
167+
/*dataExvIsAssumedSize=*/false, loc);
168+
169+
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
170+
builder, loc, coordRef.getType(), coordRef,
171+
mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
172+
op.getMapTypeAttr(),
173+
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
174+
mlir::omp::VariableCaptureKind::ByRef),
175+
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
176+
/*members_index=*/mlir::ArrayAttr{}, bounds,
177+
/*mapperId=*/mlir::FlatSymbolRefAttr(),
178+
builder.getStringAttr(op.getNameAttr().strref() + "." + memberName +
179+
".implicit_map"),
180+
/*partial_map=*/builder.getBoolAttr(false));
181+
182+
newMapOpsForFields.emplace_back(fieldMapOp);
183+
newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
184+
}
185+
78186
/// getMemberUserList gathers all users of a particular MapInfoOp that are
79187
/// other MapInfoOp's and places them into the mapMemberUsers list, which
80188
/// records the map that the current argument MapInfoOp "op" is part of
@@ -363,7 +471,7 @@ class MapInfoFinalizationPass
363471
mlir::ArrayAttr newMembersAttr;
364472
mlir::SmallVector<mlir::Value> newMembers;
365473
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
366-
bool IsHasDeviceAddr = isHasDeviceAddr(op, target);
474+
bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);
367475

368476
if (!mapMemberUsers.empty() || !op.getMembers().empty())
369477
getMemberIndicesAsVectors(
@@ -406,7 +514,7 @@ class MapInfoFinalizationPass
406514
mapUser.parent.getMembersMutable().assign(newMemberOps);
407515
mapUser.parent.setMembersIndexAttr(
408516
builder.create2DI64ArrayAttr(memberIndices));
409-
} else if (!IsHasDeviceAddr) {
517+
} else if (!isHasDeviceAddrFlag) {
410518
auto baseAddr =
411519
genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
412520
newMembers.push_back(baseAddr);
@@ -429,7 +537,7 @@ class MapInfoFinalizationPass
429537
// The contents of the descriptor (the base address in particular) will
430538
// remain unchanged though.
431539
uint64_t mapType = op.getMapType();
432-
if (IsHasDeviceAddr) {
540+
if (isHasDeviceAddrFlag) {
433541
mapType |= llvm::to_underlying(
434542
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
435543
}
@@ -701,105 +809,143 @@ class MapInfoFinalizationPass
701809

702810
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703811
llvm::SmallVector<mlir::Value> newMapOpsForFields;
704-
llvm::SmallVector<int64_t> fieldIndicies;
812+
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
705813

814+
// 1) Handle direct top-level allocatable fields.
706815
for (auto fieldMemTyPair : recordType.getTypeList()) {
707816
auto &field = fieldMemTyPair.first;
708817
auto memTy = fieldMemTyPair.second;
709818

710-
bool shouldMapField =
711-
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712-
if (!fir::isAllocatableType(memTy))
713-
return false;
714-
715-
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716-
if (!designateOp)
717-
return false;
718-
719-
return designateOp.getComponent() &&
720-
designateOp.getComponent()->strref() == field;
721-
}) != mapVarForwardSlice.end();
722-
723-
// TODO Handle recursive record types. Adapting
724-
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
725-
// entities might be helpful here.
726-
727-
if (!shouldMapField)
819+
if (!fir::isAllocatableType(memTy))
728820
continue;
729821

730-
int32_t fieldIdx = recordType.getFieldIndex(field);
731-
bool alreadyMapped = [&]() {
732-
if (op.getMembersIndexAttr())
733-
for (auto indexList : op.getMembersIndexAttr()) {
734-
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
735-
if (indexListAttr.size() == 1 &&
736-
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
737-
fieldIdx)
738-
return true;
739-
}
740-
741-
return false;
742-
}();
743-
744-
if (alreadyMapped)
822+
bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
823+
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
824+
return designateOp && designateOp.getComponent() &&
825+
designateOp.getComponent()->strref() == field;
826+
});
827+
if (!referenced)
745828
continue;
746829

830+
int32_t fieldIdx = recordType.getFieldIndex(field);
747831
builder.setInsertionPoint(op);
748832
fir::IntOrValue idxConst =
749833
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
750834
auto fieldCoord = fir::CoordinateOp::create(
751835
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
752836
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
753-
fir::factory::AddrAndBoundsInfo info =
754-
fir::factory::getDataOperandBaseAddr(
755-
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
756-
llvm::SmallVector<mlir::Value> bounds =
757-
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
758-
mlir::omp::MapBoundsType>(
759-
builder, info,
760-
hlfir::translateToExtendedValue(op.getLoc(), builder,
761-
hlfir::Entity{fieldCoord})
762-
.first,
763-
/*dataExvIsAssumedSize=*/false, op.getLoc());
764-
765-
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
766-
builder, op.getLoc(), fieldCoord.getResult().getType(),
767-
fieldCoord.getResult(),
768-
mlir::TypeAttr::get(
769-
fir::unwrapRefType(fieldCoord.getResult().getType())),
770-
op.getMapTypeAttr(),
771-
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
772-
mlir::omp::VariableCaptureKind::ByRef),
773-
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
774-
/*members_index=*/mlir::ArrayAttr{}, bounds,
775-
/*mapperId=*/mlir::FlatSymbolRefAttr(),
776-
builder.getStringAttr(op.getNameAttr().strref() + "." + field +
777-
".implicit_map"),
778-
/*partial_map=*/builder.getBoolAttr(false));
779-
newMapOpsForFields.emplace_back(fieldMapOp);
780-
fieldIndicies.emplace_back(fieldIdx);
837+
int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx);
838+
llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64};
839+
appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath,
840+
field, newMapOpsForFields, newMemberIndexPaths);
841+
}
842+
843+
// Handle nested allocatable fields along any component chain
844+
// referenced in the region via HLFIR designates.
845+
llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
846+
for (mlir::Operation *sliceOp : mapVarForwardSlice) {
847+
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
848+
if (!designateOp || !designateOp.getComponent())
849+
continue;
850+
llvm::SmallVector<llvm::StringRef> compPathReversed;
851+
compPathReversed.push_back(designateOp.getComponent()->strref());
852+
mlir::Value curBase = designateOp.getMemref();
853+
bool rootedAtMapArg = false;
854+
while (true) {
855+
if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
856+
if (!parentDes.getComponent())
857+
break;
858+
compPathReversed.push_back(parentDes.getComponent()->strref());
859+
curBase = parentDes.getMemref();
860+
continue;
861+
}
862+
if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
863+
if (auto barg =
864+
mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
865+
rootedAtMapArg = (barg == opBlockArg);
866+
} else if (auto blockArg =
867+
mlir::dyn_cast_or_null<mlir::BlockArgument>(
868+
curBase)) {
869+
rootedAtMapArg = (blockArg == opBlockArg);
870+
}
871+
break;
872+
}
873+
// Only process nested paths (2+ components). Single-component paths
874+
// for direct fields are handled above.
875+
if (!rootedAtMapArg || compPathReversed.size() < 2)
876+
continue;
877+
builder.setInsertionPoint(op);
878+
llvm::SmallVector<int64_t> indexPath;
879+
mlir::Type curTy = underlyingType;
880+
mlir::Value coordRef = op.getVarPtr();
881+
bool validPath = true;
882+
for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
883+
auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
884+
if (!recTy) {
885+
validPath = false;
886+
break;
887+
}
888+
int32_t idx = recTy.getFieldIndex(compName);
889+
if (idx < 0) {
890+
validPath = false;
891+
break;
892+
}
893+
indexPath.push_back(idx);
894+
mlir::Type memTy = recTy.getType(idx);
895+
fir::IntOrValue idxConst =
896+
mlir::IntegerAttr::get(builder.getI32Type(), idx);
897+
coordRef = fir::CoordinateOp::create(
898+
builder, op.getLoc(), builder.getRefType(memTy), coordRef,
899+
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
900+
curTy = memTy;
901+
}
902+
if (!validPath)
903+
continue;
904+
if (auto finalRefTy =
905+
mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
906+
mlir::Type eleTy = finalRefTy.getElementType();
907+
if (fir::isAllocatableType(eleTy)) {
908+
if (!containsPath(seenIndexPaths, indexPath)) {
909+
seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
910+
appendMemberMapIfNew(op, builder, op.getLoc(), coordRef,
911+
indexPath, compPathReversed.front(),
912+
newMapOpsForFields, newMemberIndexPaths);
913+
}
914+
}
915+
}
781916
}
782917

783918
if (newMapOpsForFields.empty())
784919
return mlir::WalkResult::advance();
785920

786-
op.getMembersMutable().append(newMapOpsForFields);
921+
// Deduplicate by index path to avoid emitting duplicate members for
922+
// the same component. Use a set-based key to keep this near O(n).
923+
llvm::SmallVector<mlir::Value> dedupMapOps;
924+
llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
925+
llvm::StringSet<> seenKeys;
926+
for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
927+
const auto &path = newMemberIndexPaths[i];
928+
llvm::SmallString<64> key;
929+
buildPathKey(path, key);
930+
if (seenKeys.contains(key))
931+
continue;
932+
seenKeys.insert(key);
933+
dedupMapOps.push_back(mapOp);
934+
dedupIndexPaths.emplace_back(path.begin(), path.end());
935+
}
936+
op.getMembersMutable().append(dedupMapOps);
787937
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
788-
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
789-
790-
if (oldMembersIdxAttr)
791-
for (mlir::Attribute indexList : oldMembersIdxAttr) {
938+
if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
939+
for (mlir::Attribute indexList : oldAttr) {
792940
llvm::SmallVector<int64_t> listVec;
793941

794942
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
795943
listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
796944

797945
newMemberIndices.emplace_back(std::move(listVec));
798946
}
799-
800-
for (int64_t newFieldIdx : fieldIndicies)
801-
newMemberIndices.emplace_back(
802-
llvm::SmallVector<int64_t>(1, newFieldIdx));
947+
for (auto &path : dedupIndexPaths)
948+
newMemberIndices.emplace_back(path);
803949

804950
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
805951
op.setPartialMap(true);

0 commit comments

Comments
 (0)