4141#include " mlir/Pass/Pass.h"
4242#include " mlir/Support/LLVM.h"
4343#include " llvm/ADT/SmallPtrSet.h"
44+ #include " llvm/ADT/StringSet.h"
4445#include " llvm/Frontend/OpenMP/OMPConstants.h"
46+ #include " llvm/Support/raw_ostream.h"
4547#include < algorithm>
4648#include < cstddef>
4749#include < iterator>
@@ -75,6 +77,112 @@ class MapInfoFinalizationPass
7577 // / | |
7678 std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
7779
80+ // / Return true if the given path exists in a list of paths.
81+ static bool
82+ containsPath (const llvm::SmallVectorImpl<llvm::SmallVector<int64_t >> &paths,
83+ llvm::ArrayRef<int64_t > path) {
84+ return llvm::any_of (paths, [&](const llvm::SmallVector<int64_t > &p) {
85+ return p.size () == path.size () &&
86+ std::equal (p.begin (), p.end (), path.begin ());
87+ });
88+ }
89+
90+ // / Return true if the given path is already present in
91+ // / op.getMembersIndexAttr().
92+ static bool mappedIndexPathExists (mlir::omp::MapInfoOp op,
93+ llvm::ArrayRef<int64_t > indexPath) {
94+ if (mlir::ArrayAttr attr = op.getMembersIndexAttr ()) {
95+ for (mlir::Attribute list : attr) {
96+ auto listAttr = mlir::cast<mlir::ArrayAttr>(list);
97+ if (listAttr.size () != indexPath.size ())
98+ continue ;
99+ bool allEq = true ;
100+ for (auto [i, val] : llvm::enumerate (listAttr)) {
101+ if (mlir::cast<mlir::IntegerAttr>(val).getInt () != indexPath[i]) {
102+ allEq = false ;
103+ break ;
104+ }
105+ }
106+ if (allEq)
107+ return true ;
108+ }
109+ }
110+ return false ;
111+ }
112+
113+ // / Build a compact string key for an index path for set-based
114+ // / deduplication. Format: "N:v0,v1,..." where N is the length.
115+ static void buildPathKey (llvm::ArrayRef<int64_t > path,
116+ llvm::SmallString<64 > &outKey) {
117+ outKey.clear ();
118+ llvm::raw_svector_ostream os (outKey);
119+ os << path.size () << ' :' ;
120+ for (size_t i = 0 ; i < path.size (); ++i) {
121+ if (i)
122+ os << ' ,' ;
123+ os << path[i];
124+ }
125+ }
126+
127+ // / Create the member map for coordRef and append it (and its index
128+ // / path) to the provided new* vectors, if it is not already present.
129+ void appendMemberMapIfNew (
130+ mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc,
131+ mlir::Value coordRef, llvm::ArrayRef<int64_t > indexPath,
132+ llvm::StringRef memberName,
133+ llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields,
134+ llvm::SmallVectorImpl<llvm::SmallVector<int64_t >> &newMemberIndexPaths) {
135+ // Local de-dup within this op invocation.
136+ if (containsPath (newMemberIndexPaths, indexPath))
137+ return ;
138+ // Global de-dup against already present member indices.
139+ if (mappedIndexPathExists (op, indexPath))
140+ return ;
141+
142+ if (op.getMapperId ()) {
143+ mlir::omp::DeclareMapperOp symbol =
144+ mlir::SymbolTable::lookupNearestSymbolFrom<
145+ mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr ());
146+ assert (symbol && " missing symbol for declare mapper identifier" );
147+ mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo ();
148+ // TODO: Probably a way to cache these keys in someway so we don't
149+ // constantly go through the process of rebuilding them on every check, to
150+ // save some cycles, but it can wait for a subsequent patch.
151+ for (auto v : mapperInfo.getMapVars ()) {
152+ mlir::omp::MapInfoOp map =
153+ mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp ());
154+ if (!map.getMembers ().empty () && mappedIndexPathExists (map, indexPath))
155+ return ;
156+ }
157+ }
158+
159+ builder.setInsertionPoint (op);
160+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr (
161+ builder, coordRef, /* isOptional=*/ false , loc);
162+ llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps<
163+ mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
164+ builder, info,
165+ hlfir::translateToExtendedValue (loc, builder, hlfir::Entity{coordRef})
166+ .first ,
167+ /* dataExvIsAssumedSize=*/ false , loc);
168+
169+ mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create (
170+ builder, loc, coordRef.getType (), coordRef,
171+ mlir::TypeAttr::get (fir::unwrapRefType (coordRef.getType ())),
172+ op.getMapTypeAttr (),
173+ builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
174+ mlir::omp::VariableCaptureKind::ByRef),
175+ /* varPtrPtr=*/ mlir::Value{}, /* members=*/ mlir::ValueRange{},
176+ /* members_index=*/ mlir::ArrayAttr{}, bounds,
177+ /* mapperId=*/ mlir::FlatSymbolRefAttr (),
178+ builder.getStringAttr (op.getNameAttr ().strref () + " ." + memberName +
179+ " .implicit_map" ),
180+ /* partial_map=*/ builder.getBoolAttr (false ));
181+
182+ newMapOpsForFields.emplace_back (fieldMapOp);
183+ newMemberIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
184+ }
185+
78186 // / getMemberUserList gathers all users of a particular MapInfoOp that are
79187 // / other MapInfoOp's and places them into the mapMemberUsers list, which
80188 // / records the map that the current argument MapInfoOp "op" is part of
@@ -363,7 +471,7 @@ class MapInfoFinalizationPass
363471 mlir::ArrayAttr newMembersAttr;
364472 mlir::SmallVector<mlir::Value> newMembers;
365473 llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
366- bool IsHasDeviceAddr = isHasDeviceAddr (op, target);
474+ bool isHasDeviceAddrFlag = isHasDeviceAddr (op, target);
367475
368476 if (!mapMemberUsers.empty () || !op.getMembers ().empty ())
369477 getMemberIndicesAsVectors (
@@ -406,7 +514,7 @@ class MapInfoFinalizationPass
406514 mapUser.parent .getMembersMutable ().assign (newMemberOps);
407515 mapUser.parent .setMembersIndexAttr (
408516 builder.create2DI64ArrayAttr (memberIndices));
409- } else if (!IsHasDeviceAddr ) {
517+ } else if (!isHasDeviceAddrFlag ) {
410518 auto baseAddr =
411519 genBaseAddrMap (descriptor, op.getBounds (), op.getMapType (), builder);
412520 newMembers.push_back (baseAddr);
@@ -429,7 +537,7 @@ class MapInfoFinalizationPass
429537 // The contents of the descriptor (the base address in particular) will
430538 // remain unchanged though.
431539 uint64_t mapType = op.getMapType ();
432- if (IsHasDeviceAddr ) {
540+ if (isHasDeviceAddrFlag ) {
433541 mapType |= llvm::to_underlying (
434542 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
435543 }
@@ -701,105 +809,143 @@ class MapInfoFinalizationPass
701809
702810 auto recordType = mlir::cast<fir::RecordType>(underlyingType);
703811 llvm::SmallVector<mlir::Value> newMapOpsForFields;
704- llvm::SmallVector<int64_t > fieldIndicies ;
812+ llvm::SmallVector<llvm::SmallVector< int64_t >> newMemberIndexPaths ;
705813
814+ // 1) Handle direct top-level allocatable fields.
706815 for (auto fieldMemTyPair : recordType.getTypeList ()) {
707816 auto &field = fieldMemTyPair.first ;
708817 auto memTy = fieldMemTyPair.second ;
709818
710- bool shouldMapField =
711- llvm::find_if (mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
712- if (!fir::isAllocatableType (memTy))
713- return false ;
714-
715- auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
716- if (!designateOp)
717- return false ;
718-
719- return designateOp.getComponent () &&
720- designateOp.getComponent ()->strref () == field;
721- }) != mapVarForwardSlice.end ();
722-
723- // TODO Handle recursive record types. Adapting
724- // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
725- // entities might be helpful here.
726-
727- if (!shouldMapField)
819+ if (!fir::isAllocatableType (memTy))
728820 continue ;
729821
730- int32_t fieldIdx = recordType.getFieldIndex (field);
731- bool alreadyMapped = [&]() {
732- if (op.getMembersIndexAttr ())
733- for (auto indexList : op.getMembersIndexAttr ()) {
734- auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
735- if (indexListAttr.size () == 1 &&
736- mlir::cast<mlir::IntegerAttr>(indexListAttr[0 ]).getInt () ==
737- fieldIdx)
738- return true ;
739- }
740-
741- return false ;
742- }();
743-
744- if (alreadyMapped)
822+ bool referenced = llvm::any_of (mapVarForwardSlice, [&](auto *opv) {
823+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
824+ return designateOp && designateOp.getComponent () &&
825+ designateOp.getComponent ()->strref () == field;
826+ });
827+ if (!referenced)
745828 continue ;
746829
830+ int32_t fieldIdx = recordType.getFieldIndex (field);
747831 builder.setInsertionPoint (op);
748832 fir::IntOrValue idxConst =
749833 mlir::IntegerAttr::get (builder.getI32Type (), fieldIdx);
750834 auto fieldCoord = fir::CoordinateOp::create (
751835 builder, op.getLoc (), builder.getRefType (memTy), op.getVarPtr (),
752836 llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
753- fir::factory::AddrAndBoundsInfo info =
754- fir::factory::getDataOperandBaseAddr (
755- builder, fieldCoord, /* isOptional=*/ false , op.getLoc ());
756- llvm::SmallVector<mlir::Value> bounds =
757- fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
758- mlir::omp::MapBoundsType>(
759- builder, info,
760- hlfir::translateToExtendedValue (op.getLoc (), builder,
761- hlfir::Entity{fieldCoord})
762- .first ,
763- /* dataExvIsAssumedSize=*/ false , op.getLoc ());
764-
765- mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create (
766- builder, op.getLoc (), fieldCoord.getResult ().getType (),
767- fieldCoord.getResult (),
768- mlir::TypeAttr::get (
769- fir::unwrapRefType (fieldCoord.getResult ().getType ())),
770- op.getMapTypeAttr (),
771- builder.getAttr <mlir::omp::VariableCaptureKindAttr>(
772- mlir::omp::VariableCaptureKind::ByRef),
773- /* varPtrPtr=*/ mlir::Value{}, /* members=*/ mlir::ValueRange{},
774- /* members_index=*/ mlir::ArrayAttr{}, bounds,
775- /* mapperId=*/ mlir::FlatSymbolRefAttr (),
776- builder.getStringAttr (op.getNameAttr ().strref () + " ." + field +
777- " .implicit_map" ),
778- /* partial_map=*/ builder.getBoolAttr (false ));
779- newMapOpsForFields.emplace_back (fieldMapOp);
780- fieldIndicies.emplace_back (fieldIdx);
837+ int64_t fieldIdx64 = static_cast <int64_t >(fieldIdx);
838+ llvm::SmallVector<int64_t , 1 > idxPath{fieldIdx64};
839+ appendMemberMapIfNew (op, builder, op.getLoc (), fieldCoord, idxPath,
840+ field, newMapOpsForFields, newMemberIndexPaths);
841+ }
842+
843+ // Handle nested allocatable fields along any component chain
844+ // referenced in the region via HLFIR designates.
845+ llvm::SmallVector<llvm::SmallVector<int64_t >> seenIndexPaths;
846+ for (mlir::Operation *sliceOp : mapVarForwardSlice) {
847+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
848+ if (!designateOp || !designateOp.getComponent ())
849+ continue ;
850+ llvm::SmallVector<llvm::StringRef> compPathReversed;
851+ compPathReversed.push_back (designateOp.getComponent ()->strref ());
852+ mlir::Value curBase = designateOp.getMemref ();
853+ bool rootedAtMapArg = false ;
854+ while (true ) {
855+ if (auto parentDes = curBase.getDefiningOp <hlfir::DesignateOp>()) {
856+ if (!parentDes.getComponent ())
857+ break ;
858+ compPathReversed.push_back (parentDes.getComponent ()->strref ());
859+ curBase = parentDes.getMemref ();
860+ continue ;
861+ }
862+ if (auto decl = curBase.getDefiningOp <hlfir::DeclareOp>()) {
863+ if (auto barg =
864+ mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref ()))
865+ rootedAtMapArg = (barg == opBlockArg);
866+ } else if (auto blockArg =
867+ mlir::dyn_cast_or_null<mlir::BlockArgument>(
868+ curBase)) {
869+ rootedAtMapArg = (blockArg == opBlockArg);
870+ }
871+ break ;
872+ }
873+ // Only process nested paths (2+ components). Single-component paths
874+ // for direct fields are handled above.
875+ if (!rootedAtMapArg || compPathReversed.size () < 2 )
876+ continue ;
877+ builder.setInsertionPoint (op);
878+ llvm::SmallVector<int64_t > indexPath;
879+ mlir::Type curTy = underlyingType;
880+ mlir::Value coordRef = op.getVarPtr ();
881+ bool validPath = true ;
882+ for (llvm::StringRef compName : llvm::reverse (compPathReversed)) {
883+ auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
884+ if (!recTy) {
885+ validPath = false ;
886+ break ;
887+ }
888+ int32_t idx = recTy.getFieldIndex (compName);
889+ if (idx < 0 ) {
890+ validPath = false ;
891+ break ;
892+ }
893+ indexPath.push_back (idx);
894+ mlir::Type memTy = recTy.getType (idx);
895+ fir::IntOrValue idxConst =
896+ mlir::IntegerAttr::get (builder.getI32Type (), idx);
897+ coordRef = fir::CoordinateOp::create (
898+ builder, op.getLoc (), builder.getRefType (memTy), coordRef,
899+ llvm::SmallVector<fir::IntOrValue, 1 >{idxConst});
900+ curTy = memTy;
901+ }
902+ if (!validPath)
903+ continue ;
904+ if (auto finalRefTy =
905+ mlir::dyn_cast<fir::ReferenceType>(coordRef.getType ())) {
906+ mlir::Type eleTy = finalRefTy.getElementType ();
907+ if (fir::isAllocatableType (eleTy)) {
908+ if (!containsPath (seenIndexPaths, indexPath)) {
909+ seenIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
910+ appendMemberMapIfNew (op, builder, op.getLoc (), coordRef,
911+ indexPath, compPathReversed.front (),
912+ newMapOpsForFields, newMemberIndexPaths);
913+ }
914+ }
915+ }
781916 }
782917
783918 if (newMapOpsForFields.empty ())
784919 return mlir::WalkResult::advance ();
785920
786- op.getMembersMutable ().append (newMapOpsForFields);
921+ // Deduplicate by index path to avoid emitting duplicate members for
922+ // the same component. Use a set-based key to keep this near O(n).
923+ llvm::SmallVector<mlir::Value> dedupMapOps;
924+ llvm::SmallVector<llvm::SmallVector<int64_t >> dedupIndexPaths;
925+ llvm::StringSet<> seenKeys;
926+ for (auto [i, mapOp] : llvm::enumerate (newMapOpsForFields)) {
927+ const auto &path = newMemberIndexPaths[i];
928+ llvm::SmallString<64 > key;
929+ buildPathKey (path, key);
930+ if (seenKeys.contains (key))
931+ continue ;
932+ seenKeys.insert (key);
933+ dedupMapOps.push_back (mapOp);
934+ dedupIndexPaths.emplace_back (path.begin (), path.end ());
935+ }
936+ op.getMembersMutable ().append (dedupMapOps);
787937 llvm::SmallVector<llvm::SmallVector<int64_t >> newMemberIndices;
788- mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr ();
789-
790- if (oldMembersIdxAttr)
791- for (mlir::Attribute indexList : oldMembersIdxAttr) {
938+ if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr ())
939+ for (mlir::Attribute indexList : oldAttr) {
792940 llvm::SmallVector<int64_t > listVec;
793941
794942 for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
795943 listVec.push_back (mlir::cast<mlir::IntegerAttr>(index).getInt ());
796944
797945 newMemberIndices.emplace_back (std::move (listVec));
798946 }
799-
800- for (int64_t newFieldIdx : fieldIndicies)
801- newMemberIndices.emplace_back (
802- llvm::SmallVector<int64_t >(1 , newFieldIdx));
947+ for (auto &path : dedupIndexPaths)
948+ newMemberIndices.emplace_back (path);
803949
804950 op.setMembersIndexAttr (builder.create2DI64ArrayAttr (newMemberIndices));
805951 op.setPartialMap (true );
0 commit comments