@@ -2874,39 +2874,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
28742874 return std::distance (mapData.MapClause .begin (), res);
28752875}
28762876
2877- static omp::MapInfoOp getFirstOrLastMappedMemberPtr (omp::MapInfoOp mapInfo,
2878- bool first) {
2879- ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2880- // Only 1 member has been mapped, we can return it.
2881- if (indexAttr.size () == 1 )
2882- return cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ());
2877+ static void sortMapIndices (llvm::SmallVector<size_t > &indices,
2878+ mlir::omp::MapInfoOp mapInfo,
2879+ bool ascending = true ) {
2880+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2881+ if (indexAttr.empty () || indexAttr.size () == 1 || indices.empty () ||
2882+ indices.size () == 1 )
2883+ return ;
28832884
2884- llvm::SmallVector<size_t > indices (indexAttr.size ());
2885- std::iota (indices.begin (), indices.end (), 0 );
2885+ llvm::sort (
2886+ indices.begin (), indices.end (), [&](const size_t a, const size_t b) {
2887+ auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2888+ auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2889+
2890+ size_t smallestMember = memberIndicesA.size () < memberIndicesB.size ()
2891+ ? memberIndicesA.size ()
2892+ : memberIndicesB.size ();
28862893
2887- llvm::sort (indices. begin (), indices. end (),
2888- [&]( const size_t a, const size_t b) {
2889- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2890- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b] );
2891- for ( const auto it : llvm::zip (memberIndicesA, memberIndicesB)) {
2892- int64_t aIndex = cast<IntegerAttr>(std::get< 0 >(it)). getInt ();
2893- int64_t bIndex = cast<IntegerAttr>(std::get< 1 >(it)) .getInt ();
2894+ for ( size_t i = 0 ; i < smallestMember; ++i) {
2895+ int64_t aIndex =
2896+ mlir:: cast<mlir::IntegerAttr>(memberIndicesA. getValue ()[i])
2897+ . getInt ( );
2898+ int64_t bIndex =
2899+ mlir:: cast<mlir::IntegerAttr>(memberIndicesB. getValue ()[i])
2900+ .getInt ();
28942901
2895- if (aIndex == bIndex)
2896- continue ;
2902+ if (aIndex == bIndex)
2903+ continue ;
28972904
2898- if (aIndex < bIndex)
2899- return first ;
2905+ if (aIndex < bIndex)
2906+ return ascending ;
29002907
2901- if (aIndex > bIndex)
2902- return !first ;
2903- }
2908+ if (aIndex > bIndex)
2909+ return !ascending ;
2910+ }
29042911
2905- // Iterated the up until the end of the smallest member and
2906- // they were found to be equal up to that point, so select
2907- // the member with the lowest index count, so the "parent"
2908- return memberIndicesA.size () < memberIndicesB.size ();
2909- });
2912+ // Iterated up until the end of the smallest member and
2913+ // they were found to be equal up to that point, so select
2914+ // the member with the lowest index count, so the "parent"
2915+ return memberIndicesA.size () < memberIndicesB.size ();
2916+ });
2917+ }
2918+
2919+ static mlir::omp::MapInfoOp
2920+ getFirstOrLastMappedMemberPtr (mlir::omp::MapInfoOp mapInfo, bool first) {
2921+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2922+ // Only 1 member has been mapped, we can return it.
2923+ if (indexAttr.size () == 1 )
2924+ if (auto mapOp =
2925+ dyn_cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ()))
2926+ return mapOp;
2927+
2928+ llvm::SmallVector<size_t > indices;
2929+ indices.resize (indexAttr.size ());
2930+ std::iota (indices.begin (), indices.end (), 0 );
2931+ sortMapIndices (indices, mapInfo, first);
29102932
29112933 return llvm::cast<omp::MapInfoOp>(
29122934 mapInfo.getMembers ()[indices.front ()].getDefiningOp ());
@@ -3005,6 +3027,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
30053027 return idx;
30063028}
30073029
3030+ // Gathers members that are overlapping in the parent, excluding members that
3031+ // themselves overlap, keeping the top-most (closest to parents level) map.
3032+ static void getOverlappedMembers (llvm::SmallVector<size_t > &overlapMapDataIdxs,
3033+ MapInfoData &mapData,
3034+ omp::MapInfoOp parentOp) {
3035+ // No members mapped, no overlaps.
3036+ if (parentOp.getMembers ().empty ())
3037+ return ;
3038+
3039+ // Single member, we can insert and return early.
3040+ if (parentOp.getMembers ().size () == 1 ) {
3041+ overlapMapDataIdxs.push_back (0 );
3042+ return ;
3043+ }
3044+
3045+ // 1) collect list of top-level overlapping members from MemberOp
3046+ llvm::SmallVector<std::pair<int , mlir::ArrayAttr>> memberByIndex;
3047+ mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr ();
3048+ for (auto [memIndex, indicesAttr] : llvm::enumerate (indexAttr))
3049+ memberByIndex.push_back (
3050+ std::make_pair (memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));
3051+
3052+ // Sort the smallest first (higher up the parent -> member chain), so that
3053+ // when we remove members, we remove as much as we can in the initial
3054+ // iterations, shortening the number of passes required.
3055+ llvm::sort (memberByIndex.begin (), memberByIndex.end (),
3056+ [&](auto a, auto b) { return a.second .size () < b.second .size (); });
3057+
3058+ auto getAsIntegers = [](mlir::ArrayAttr values) {
3059+ llvm::SmallVector<int64_t > ints;
3060+ ints.reserve (values.size ());
3061+ llvm::transform (values, std::back_inserter (ints),
3062+ [](mlir::Attribute value) {
3063+ return mlir::cast<mlir::IntegerAttr>(value).getInt ();
3064+ });
3065+ return ints;
3066+ };
3067+
3068+ // Remove elements from the vector if there is a parent element that
3069+ // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
3070+ // [0,2].. etc.
3071+ for (auto v : make_early_inc_range (memberByIndex)) {
3072+ auto vArr = getAsIntegers (v.second );
3073+ memberByIndex.erase (
3074+ std::remove_if (memberByIndex.begin (), memberByIndex.end (),
3075+ [&](auto x) {
3076+ if (v == x)
3077+ return false ;
3078+
3079+ auto xArr = getAsIntegers (x.second );
3080+ return std::equal (vArr.begin (), vArr.end (),
3081+ xArr.begin ()) &&
3082+ xArr.size () >= vArr.size ();
3083+ }),
3084+ memberByIndex.end ());
3085+ }
3086+
3087+ // Collect the indices from mapData that we need, as we technically need the
3088+ // base pointer etc. info, which is stored in there and primarily accessible
3089+ // via index at the moment.
3090+ for (auto v : memberByIndex)
3091+ overlapMapDataIdxs.push_back (v.first );
3092+ }
3093+
3094+ // The intent is to verify if the mapped data being passed is a
3095+ // pointer -> pointee that requires special handling in certain cases,
3096+ // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3097+ //
3098+ // There may be a better way to verify this, but unfortunately with
3099+ // opaque pointers we lose the ability to easily check if something is
3100+ // a pointer whilst maintaining access to the underlying type.
3101+ static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3102+ // If we have a varPtrPtr field assigned then the underlying type is a pointer
3103+ if (mapOp.getVarPtrPtr ())
3104+ return true ;
3105+
3106+ // If the map data is declare target with a link clause, then it's represented
3107+ // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3108+ // no relation to pointers.
3109+ if (isDeclareTargetLink (mapOp.getVarPtr ()))
3110+ return true ;
3111+
3112+ return false ;
3113+ }
3114+
30083115// This creates two insertions into the MapInfosTy data structure for the
30093116// "parent" of a set of members, (usually a container e.g.
30103117// class/structure/derived type) when subsequent members have also been
@@ -3045,7 +3152,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
30453152 // runtime information on the dynamically allocated data).
30463153 auto parentClause =
30473154 llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
3048-
30493155 llvm::Value *lowAddr, *highAddr;
30503156 if (!parentClause.getPartialMap ()) {
30513157 lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
@@ -3092,37 +3198,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
30923198 // what we support as expected.
30933199 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types [mapDataIndex];
30943200 ompBuilder.setCorrectMemberOfFlag (mapFlag, memberOfFlag);
3095- combinedInfo.Types .emplace_back (mapFlag);
3096- combinedInfo.DevicePointers .emplace_back (
3097- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3098- combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3099- mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3100- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
3101- combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3102- combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3103- }
3104- return memberOfFlag;
3105- }
3106-
3107- // The intent is to verify if the mapped data being passed is a
3108- // pointer -> pointee that requires special handling in certain cases,
3109- // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3110- //
3111- // There may be a better way to verify this, but unfortunately with
3112- // opaque pointers we lose the ability to easily check if something is
3113- // a pointer whilst maintaining access to the underlying type.
3114- static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3115- // If we have a varPtrPtr field assigned then the underlying type is a pointer
3116- if (mapOp.getVarPtrPtr ())
3117- return true ;
31183201
3119- // If the map data is declare target with a link clause, then it's represented
3120- // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3121- // no relation to pointers.
3122- if (isDeclareTargetLink (mapOp.getVarPtr ()))
3123- return true ;
3202+ if (targetDirective == TargetDirective::TargetUpdate) {
3203+ combinedInfo.Types .emplace_back (mapFlag);
3204+ combinedInfo.DevicePointers .emplace_back (
3205+ mapData.DevicePointers [mapDataIndex]);
3206+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3207+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3208+ combinedInfo.BasePointers .emplace_back (
3209+ mapData.BasePointers [mapDataIndex]);
3210+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3211+ combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3212+ } else {
3213+ llvm::SmallVector<size_t > overlapIdxs;
3214+ // Find all of the members that "overlap", i.e. occlude other members that
3215+ // were mapped alongside the parent, e.g. member [0], occludes
3216+ getOverlappedMembers (overlapIdxs, mapData, parentClause);
3217+ // We need to make sure the overlapped members are sorted in order of
3218+ // lowest address to highest address
3219+ sortMapIndices (overlapIdxs, parentClause);
3220+
3221+ lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
3222+ builder.getPtrTy ());
3223+ highAddr = builder.CreatePointerCast (
3224+ builder.CreateConstGEP1_32 (mapData.BaseType [mapDataIndex],
3225+ mapData.Pointers [mapDataIndex], 1 ),
3226+ builder.getPtrTy ());
3227+
3228+ // TODO: We may want to skip arrays/array sections in this as Clang does
3229+ // so it appears to be an optimisation rather than a neccessity though,
3230+ // but this requires further investigation. However, we would have to make
3231+ // sure to not exclude maps with bounds that ARE pointers, as these are
3232+ // processed as seperate components, i.e. pointer + data.
3233+ for (auto v : overlapIdxs) {
3234+ auto mapDataOverlapIdx = getMapDataMemberIdx (
3235+ mapData,
3236+ cast<omp::MapInfoOp>(parentClause.getMembers ()[v].getDefiningOp ()));
3237+ combinedInfo.Types .emplace_back (mapFlag);
3238+ combinedInfo.DevicePointers .emplace_back (
3239+ mapData.DevicePointers [mapDataOverlapIdx]);
3240+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3241+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3242+ combinedInfo.BasePointers .emplace_back (
3243+ mapData.BasePointers [mapDataIndex]);
3244+ combinedInfo.Pointers .emplace_back (lowAddr);
3245+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3246+ builder.CreatePtrDiff (builder.getInt8Ty (),
3247+ mapData.OriginalValue [mapDataOverlapIdx],
3248+ lowAddr),
3249+ builder.getInt64Ty (), /* isSigned=*/ true ));
3250+ lowAddr = builder.CreateConstGEP1_32 (
3251+ checkIfPointerMap (llvm::cast<omp::MapInfoOp>(
3252+ mapData.MapClause [mapDataOverlapIdx]))
3253+ ? builder.getPtrTy ()
3254+ : mapData.BaseType [mapDataOverlapIdx],
3255+ mapData.BasePointers [mapDataOverlapIdx], 1 );
3256+ }
31243257
3125- return false ;
3258+ combinedInfo.Types .emplace_back (mapFlag);
3259+ combinedInfo.DevicePointers .emplace_back (
3260+ mapData.DevicePointers [mapDataIndex]);
3261+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3262+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3263+ combinedInfo.BasePointers .emplace_back (
3264+ mapData.BasePointers [mapDataIndex]);
3265+ combinedInfo.Pointers .emplace_back (lowAddr);
3266+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3267+ builder.CreatePtrDiff (builder.getInt8Ty (), highAddr, lowAddr),
3268+ builder.getInt64Ty (), true ));
3269+ }
3270+ }
3271+ return memberOfFlag;
31263272}
31273273
31283274// This function is intended to add explicit mappings of members
0 commit comments