@@ -41,7 +41,9 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <cstddef>
 #include <iterator>
@@ -75,6 +77,112 @@ class MapInfoFinalizationPass
   ///        |                  |
   std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
 
+  /// Return true if the given path exists in a list of paths.
+  static bool
+  containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
+               llvm::ArrayRef<int64_t> path) {
+    return llvm::any_of(paths, [&](const llvm::SmallVector<int64_t> &p) {
+      return p.size() == path.size() &&
+             std::equal(p.begin(), p.end(), path.begin());
+    });
+  }
+
+  /// Return true if the given path is already present in
+  /// op.getMembersIndexAttr().
+  static bool mappedIndexPathExists(mlir::omp::MapInfoOp op,
+                                    llvm::ArrayRef<int64_t> indexPath) {
+    if (mlir::ArrayAttr attr = op.getMembersIndexAttr()) {
+      for (mlir::Attribute list : attr) {
+        auto listAttr = mlir::cast<mlir::ArrayAttr>(list);
+        if (listAttr.size() != indexPath.size())
+          continue;
+        bool allEq = true;
+        for (auto [i, val] : llvm::enumerate(listAttr)) {
+          if (mlir::cast<mlir::IntegerAttr>(val).getInt() != indexPath[i]) {
+            allEq = false;
+            break;
+          }
+        }
+        if (allEq)
+          return true;
+      }
+    }
+    return false;
+  }
+
+  /// Build a compact string key for an index path for set-based
+  /// deduplication. Format: "N:v0,v1,..." where N is the length.
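+  /// For example, the index path {1, 3} produces the key "2:1,3".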
+  static void buildPathKey(llvm::ArrayRef<int64_t> path,
+                           llvm::SmallString<64> &outKey) {
+    outKey.clear();
+    llvm::raw_svector_ostream os(outKey);
+    os << path.size() << ':';
+    for (size_t i = 0; i < path.size(); ++i) {
+      if (i)
+        os << ',';
+      os << path[i];
+    }
+  }
+
+  /// Create the member map for coordRef and append it (and its index
+  /// path) to the provided new* vectors, if it is not already present.
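+  /// A path is skipped if it is already recorded in newMemberIndexPaths, is
+  /// already present in op's member indices, or is already covered by the
+  /// member maps of an associated declare-mapper.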
+  void appendMemberMapIfNew(
+      mlir::omp::MapInfoOp op, fir::FirOpBuilder &builder, mlir::Location loc,
+      mlir::Value coordRef, llvm::ArrayRef<int64_t> indexPath,
+      llvm::StringRef memberName,
+      llvm::SmallVectorImpl<mlir::Value> &newMapOpsForFields,
+      llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &newMemberIndexPaths) {
+    // Local de-dup within this op invocation.
+    if (containsPath(newMemberIndexPaths, indexPath))
+      return;
+    // Global de-dup against already present member indices.
+    if (mappedIndexPathExists(op, indexPath))
+      return;
+
+    if (op.getMapperId()) {
+      mlir::omp::DeclareMapperOp symbol =
+          mlir::SymbolTable::lookupNearestSymbolFrom<
+              mlir::omp::DeclareMapperOp>(op, op.getMapperIdAttr());
+      assert(symbol && "missing symbol for declare mapper identifier");
+      mlir::omp::DeclareMapperInfoOp mapperInfo = symbol.getDeclareMapperInfo();
+      // TODO: There is probably a way to cache these keys so we do not rebuild
+      // them on every check and save some cycles, but it can wait for a
+      // subsequent patch.
+      for (auto v : mapperInfo.getMapVars()) {
+        mlir::omp::MapInfoOp map =
+            mlir::cast<mlir::omp::MapInfoOp>(v.getDefiningOp());
+        if (!map.getMembers().empty() && mappedIndexPathExists(map, indexPath))
+          return;
+      }
+    }
+
+    builder.setInsertionPoint(op);
+    fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+        builder, coordRef, /*isOptional=*/false, loc);
+    llvm::SmallVector<mlir::Value> bounds = fir::factory::genImplicitBoundsOps<
+        mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
+        builder, info,
+        hlfir::translateToExtendedValue(loc, builder, hlfir::Entity{coordRef})
+            .first,
+        /*dataExvIsAssumedSize=*/false, loc);
+
+    mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
+        builder, loc, coordRef.getType(), coordRef,
+        mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
+        op.getMapTypeAttr(),
+        builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+            mlir::omp::VariableCaptureKind::ByRef),
+        /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
+        /*members_index=*/mlir::ArrayAttr{}, bounds,
+        /*mapperId=*/mlir::FlatSymbolRefAttr(),
+        builder.getStringAttr(op.getNameAttr().strref() + "." + memberName +
+                              ".implicit_map"),
+        /*partial_map=*/builder.getBoolAttr(false));
+
+    newMapOpsForFields.emplace_back(fieldMapOp);
+    newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+  }
+
   /// getMemberUserList gathers all users of a particular MapInfoOp that are
   /// other MapInfoOps and places them into the mapMemberUsers list, which
   /// records the map that the current argument MapInfoOp "op" is part of
@@ -363,7 +471,7 @@ class MapInfoFinalizationPass
     mlir::ArrayAttr newMembersAttr;
     mlir::SmallVector<mlir::Value> newMembers;
     llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
-    bool IsHasDeviceAddr = isHasDeviceAddr(op, target);
+    bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);
 
     if (!mapMemberUsers.empty() || !op.getMembers().empty())
       getMemberIndicesAsVectors(
@@ -406,7 +514,7 @@ class MapInfoFinalizationPass
       mapUser.parent.getMembersMutable().assign(newMemberOps);
       mapUser.parent.setMembersIndexAttr(
           builder.create2DI64ArrayAttr(memberIndices));
-    } else if (!IsHasDeviceAddr) {
+    } else if (!isHasDeviceAddrFlag) {
       auto baseAddr =
           genBaseAddrMap(descriptor, op.getBounds(), op.getMapType(), builder);
       newMembers.push_back(baseAddr);
@@ -429,7 +537,7 @@ class MapInfoFinalizationPass
     // The contents of the descriptor (the base address in particular) will
     // remain unchanged though.
     uint64_t mapType = op.getMapType();
-    if (IsHasDeviceAddr) {
+    if (isHasDeviceAddrFlag) {
       mapType |= llvm::to_underlying(
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
     }
@@ -701,105 +809,143 @@ class MapInfoFinalizationPass
 
       auto recordType = mlir::cast<fir::RecordType>(underlyingType);
       llvm::SmallVector<mlir::Value> newMapOpsForFields;
-      llvm::SmallVector<int64_t> fieldIndicies;
+      llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
 
+      // 1) Handle direct top-level allocatable fields.
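+      // For example, a component declared as "integer, allocatable :: a(:)"
+      // directly inside the mapped derived type.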
       for (auto fieldMemTyPair : recordType.getTypeList()) {
         auto &field = fieldMemTyPair.first;
         auto memTy = fieldMemTyPair.second;
 
-        bool shouldMapField =
-            llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
-              if (!fir::isAllocatableType(memTy))
-                return false;
-
-              auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
-              if (!designateOp)
-                return false;
-
-              return designateOp.getComponent() &&
-                     designateOp.getComponent()->strref() == field;
-            }) != mapVarForwardSlice.end();
-
-        // TODO Handle recursive record types. Adapting
-        // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
-        // entities might be helpful here.
-
-        if (!shouldMapField)
+        if (!fir::isAllocatableType(memTy))
           continue;
 
-        int32_t fieldIdx = recordType.getFieldIndex(field);
-        bool alreadyMapped = [&]() {
-          if (op.getMembersIndexAttr())
-            for (auto indexList : op.getMembersIndexAttr()) {
-              auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
-              if (indexListAttr.size() == 1 &&
-                  mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
-                      fieldIdx)
-                return true;
-            }
-
-          return false;
-        }();
-
-        if (alreadyMapped)
+        bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
+          auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
+          return designateOp && designateOp.getComponent() &&
+                 designateOp.getComponent()->strref() == field;
+        });
+        if (!referenced)
           continue;
 
+        int32_t fieldIdx = recordType.getFieldIndex(field);
         builder.setInsertionPoint(op);
         fir::IntOrValue idxConst =
             mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
         auto fieldCoord = fir::CoordinateOp::create(
             builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
             llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
-        fir::factory::AddrAndBoundsInfo info =
-            fir::factory::getDataOperandBaseAddr(
-                builder, fieldCoord, /*isOptional=*/false, op.getLoc());
-        llvm::SmallVector<mlir::Value> bounds =
-            fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
-                                               mlir::omp::MapBoundsType>(
-                builder, info,
-                hlfir::translateToExtendedValue(op.getLoc(), builder,
-                                                hlfir::Entity{fieldCoord})
-                    .first,
-                /*dataExvIsAssumedSize=*/false, op.getLoc());
-
-        mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
-            builder, op.getLoc(), fieldCoord.getResult().getType(),
-            fieldCoord.getResult(),
-            mlir::TypeAttr::get(
-                fir::unwrapRefType(fieldCoord.getResult().getType())),
-            op.getMapTypeAttr(),
-            builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
-                mlir::omp::VariableCaptureKind::ByRef),
-            /*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
-            /*members_index=*/mlir::ArrayAttr{}, bounds,
-            /*mapperId=*/mlir::FlatSymbolRefAttr(),
-            builder.getStringAttr(op.getNameAttr().strref() + "." + field +
-                                  ".implicit_map"),
-            /*partial_map=*/builder.getBoolAttr(false));
-        newMapOpsForFields.emplace_back(fieldMapOp);
-        fieldIndicies.emplace_back(fieldIdx);
+        int64_t fieldIdx64 = static_cast<int64_t>(fieldIdx);
+        llvm::SmallVector<int64_t, 1> idxPath{fieldIdx64};
+        appendMemberMapIfNew(op, builder, op.getLoc(), fieldCoord, idxPath,
+                             field, newMapOpsForFields, newMemberIndexPaths);
+      }
+
+      // 2) Handle nested allocatable fields along any component chain
+      // referenced in the region via HLFIR designates.
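+      // For example, a reference to var%inner%alloc (where "alloc" is an
+      // allocatable component) yields the index path {idx(inner), idx(alloc)}.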
+      llvm::SmallVector<llvm::SmallVector<int64_t>> seenIndexPaths;
+      for (mlir::Operation *sliceOp : mapVarForwardSlice) {
+        auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+        if (!designateOp || !designateOp.getComponent())
+          continue;
+        llvm::SmallVector<llvm::StringRef> compPathReversed;
+        compPathReversed.push_back(designateOp.getComponent()->strref());
+        mlir::Value curBase = designateOp.getMemref();
+        bool rootedAtMapArg = false;
+        while (true) {
+          if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
+            if (!parentDes.getComponent())
+              break;
+            compPathReversed.push_back(parentDes.getComponent()->strref());
+            curBase = parentDes.getMemref();
+            continue;
+          }
+          if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
+            if (auto barg =
+                    mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
+              rootedAtMapArg = (barg == opBlockArg);
+          } else if (auto blockArg =
+                         mlir::dyn_cast_or_null<mlir::BlockArgument>(
+                             curBase)) {
+            rootedAtMapArg = (blockArg == opBlockArg);
+          }
+          break;
+        }
+        // Only process nested paths (2+ components). Single-component paths
+        // for direct fields are handled above.
+        if (!rootedAtMapArg || compPathReversed.size() < 2)
+          continue;
+        builder.setInsertionPoint(op);
+        llvm::SmallVector<int64_t> indexPath;
+        mlir::Type curTy = underlyingType;
+        mlir::Value coordRef = op.getVarPtr();
+        bool validPath = true;
+        for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
+          auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
+          if (!recTy) {
+            validPath = false;
+            break;
+          }
+          int32_t idx = recTy.getFieldIndex(compName);
+          if (idx < 0) {
+            validPath = false;
+            break;
+          }
+          indexPath.push_back(idx);
+          mlir::Type memTy = recTy.getType(idx);
+          fir::IntOrValue idxConst =
+              mlir::IntegerAttr::get(builder.getI32Type(), idx);
+          coordRef = fir::CoordinateOp::create(
+              builder, op.getLoc(), builder.getRefType(memTy), coordRef,
+              llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
+          curTy = memTy;
+        }
+        if (!validPath)
+          continue;
+        if (auto finalRefTy =
+                mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
+          mlir::Type eleTy = finalRefTy.getElementType();
+          if (fir::isAllocatableType(eleTy)) {
+            if (!containsPath(seenIndexPaths, indexPath)) {
+              seenIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
+              appendMemberMapIfNew(op, builder, op.getLoc(), coordRef,
+                                   indexPath, compPathReversed.front(),
+                                   newMapOpsForFields, newMemberIndexPaths);
+            }
+          }
+        }
       }
 
       if (newMapOpsForFields.empty())
         return mlir::WalkResult::advance();
 
-      op.getMembersMutable().append(newMapOpsForFields);
+      // Deduplicate by index path to avoid emitting duplicate members for
+      // the same component. Use a set-based key to keep this near O(n).
+      llvm::SmallVector<mlir::Value> dedupMapOps;
+      llvm::SmallVector<llvm::SmallVector<int64_t>> dedupIndexPaths;
+      llvm::StringSet<> seenKeys;
+      for (auto [i, mapOp] : llvm::enumerate(newMapOpsForFields)) {
+        const auto &path = newMemberIndexPaths[i];
+        llvm::SmallString<64> key;
+        buildPathKey(path, key);
+        if (seenKeys.contains(key))
+          continue;
+        seenKeys.insert(key);
+        dedupMapOps.push_back(mapOp);
+        dedupIndexPaths.emplace_back(path.begin(), path.end());
+      }
+      op.getMembersMutable().append(dedupMapOps);
       llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
-      mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
-
-      if (oldMembersIdxAttr)
-        for (mlir::Attribute indexList : oldMembersIdxAttr) {
+      if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
+        for (mlir::Attribute indexList : oldAttr) {
           llvm::SmallVector<int64_t> listVec;
 
           for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
             listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
 
           newMemberIndices.emplace_back(std::move(listVec));
         }
-
-      for (int64_t newFieldIdx : fieldIndicies)
-        newMemberIndices.emplace_back(
-            llvm::SmallVector<int64_t>(1, newFieldIdx));
+      for (auto &path : dedupIndexPaths)
+        newMemberIndices.emplace_back(path);
 
       op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
       op.setPartialMap(true);