@@ -2979,39 +2979,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
29792979 return std::distance (mapData.MapClause .begin (), res);
29802980}
29812981
2982- static omp::MapInfoOp getFirstOrLastMappedMemberPtr (omp::MapInfoOp mapInfo,
2983- bool first) {
2984- ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2985- // Only 1 member has been mapped, we can return it.
2986- if (indexAttr.size () == 1 )
2987- return cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ());
2982+ static void sortMapIndices (llvm::SmallVector<size_t > &indices,
2983+ mlir::omp::MapInfoOp mapInfo,
2984+ bool ascending = true ) {
2985+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
2986+ if (indexAttr.empty () || indexAttr.size () == 1 || indices.empty () ||
2987+ indices.size () == 1 )
2988+ return ;
29882989
2989- llvm::SmallVector<size_t > indices (indexAttr.size ());
2990- std::iota (indices.begin (), indices.end (), 0 );
2990+ llvm::sort (
2991+ indices.begin (), indices.end (), [&](const size_t a, const size_t b) {
2992+ auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
2993+ auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
2994+
2995+ size_t smallestMember = memberIndicesA.size () < memberIndicesB.size ()
2996+ ? memberIndicesA.size ()
2997+ : memberIndicesB.size ();
29912998
2992- llvm::sort (indices. begin (), indices. end (),
2993- [&]( const size_t a, const size_t b) {
2994- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2995- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b] );
2996- for ( const auto it : llvm::zip (memberIndicesA, memberIndicesB)) {
2997- int64_t aIndex = cast<IntegerAttr>(std::get< 0 >(it)). getInt ();
2998- int64_t bIndex = cast<IntegerAttr>(std::get< 1 >(it)) .getInt ();
2999+ for ( size_t i = 0 ; i < smallestMember; ++i) {
3000+ int64_t aIndex =
3001+ mlir:: cast<mlir::IntegerAttr>(memberIndicesA. getValue ()[i])
3002+ . getInt ( );
3003+ int64_t bIndex =
3004+ mlir:: cast<mlir::IntegerAttr>(memberIndicesB. getValue ()[i])
3005+ .getInt ();
29993006
3000- if (aIndex == bIndex)
3001- continue ;
3007+ if (aIndex == bIndex)
3008+ continue ;
30023009
3003- if (aIndex < bIndex)
3004- return first ;
3010+ if (aIndex < bIndex)
3011+ return ascending ;
30053012
3006- if (aIndex > bIndex)
3007- return !first ;
3008- }
3013+ if (aIndex > bIndex)
3014+ return !ascending ;
3015+ }
30093016
3010- // Iterated the up until the end of the smallest member and
3011- // they were found to be equal up to that point, so select
3012- // the member with the lowest index count, so the "parent"
3013- return memberIndicesA.size () < memberIndicesB.size ();
3014- });
3017+ // Iterated up until the end of the smallest member and
3018+ // they were found to be equal up to that point, so select
3019+ // the member with the lowest index count, so the "parent"
3020+ return memberIndicesA.size () < memberIndicesB.size ();
3021+ });
3022+ }
3023+
3024+ static mlir::omp::MapInfoOp
3025+ getFirstOrLastMappedMemberPtr (mlir::omp::MapInfoOp mapInfo, bool first) {
3026+ mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr ();
3027+ // Only 1 member has been mapped, we can return it.
3028+ if (indexAttr.size () == 1 )
3029+ if (auto mapOp =
3030+ dyn_cast<omp::MapInfoOp>(mapInfo.getMembers ()[0 ].getDefiningOp ()))
3031+ return mapOp;
3032+
3033+ llvm::SmallVector<size_t > indices;
3034+ indices.resize (indexAttr.size ());
3035+ std::iota (indices.begin (), indices.end (), 0 );
3036+ sortMapIndices (indices, mapInfo, first);
30153037
30163038 return llvm::cast<omp::MapInfoOp>(
30173039 mapInfo.getMembers ()[indices.front ()].getDefiningOp ());
@@ -3110,6 +3132,91 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
31103132 return idx;
31113133}
31123134
3135+ // Gathers members that are overlapping in the parent, excluding members that
3136+ // themselves overlap, keeping the top-most (closest to parents level) map.
3137+ static void getOverlappedMembers (llvm::SmallVector<size_t > &overlapMapDataIdxs,
3138+ MapInfoData &mapData,
3139+ omp::MapInfoOp parentOp) {
3140+ // No members mapped, no overlaps.
3141+ if (parentOp.getMembers ().empty ())
3142+ return ;
3143+
3144+ // Single member, we can insert and return early.
3145+ if (parentOp.getMembers ().size () == 1 ) {
3146+ overlapMapDataIdxs.push_back (0 );
3147+ return ;
3148+ }
3149+
3150+ // 1) collect list of top-level overlapping members from MemberOp
3151+ llvm::SmallVector<std::pair<int , mlir::ArrayAttr>> memberByIndex;
3152+ mlir::ArrayAttr indexAttr = parentOp.getMembersIndexAttr ();
3153+ for (auto [memIndex, indicesAttr] : llvm::enumerate (indexAttr))
3154+ memberByIndex.push_back (
3155+ std::make_pair (memIndex, mlir::cast<mlir::ArrayAttr>(indicesAttr)));
3156+
3157+ // Sort the smallest first (higher up the parent -> member chain), so that
3158+ // when we remove members, we remove as much as we can in the initial
3159+ // iterations, shortening the number of passes required.
3160+ llvm::sort (memberByIndex.begin (), memberByIndex.end (),
3161+ [&](auto a, auto b) { return a.second .size () < b.second .size (); });
3162+
3163+ auto getAsIntegers = [](mlir::ArrayAttr values) {
3164+ llvm::SmallVector<int64_t > ints;
3165+ ints.reserve (values.size ());
3166+ llvm::transform (values, std::back_inserter (ints),
3167+ [](mlir::Attribute value) {
3168+ return mlir::cast<mlir::IntegerAttr>(value).getInt ();
3169+ });
3170+ return ints;
3171+ };
3172+
3173+ // Remove elements from the vector if there is a parent element that
3174+ // supersedes it. i.e. if member [0] is mapped, we can remove members [0,1],
3175+ // [0,2].. etc.
3176+ for (auto v : make_early_inc_range (memberByIndex)) {
3177+ auto vArr = getAsIntegers (v.second );
3178+ memberByIndex.erase (
3179+ std::remove_if (memberByIndex.begin (), memberByIndex.end (),
3180+ [&](auto x) {
3181+ if (v == x)
3182+ return false ;
3183+
3184+ auto xArr = getAsIntegers (x.second );
3185+ return std::equal (vArr.begin (), vArr.end (),
3186+ xArr.begin ()) &&
3187+ xArr.size () >= vArr.size ();
3188+ }),
3189+ memberByIndex.end ());
3190+ }
3191+
3192+ // Collect the indices from mapData that we need, as we technically need the
3193+ // base pointer etc. info, which is stored in there and primarily accessible
3194+ // via index at the moment.
3195+ for (auto v : memberByIndex)
3196+ overlapMapDataIdxs.push_back (v.first );
3197+ }
3198+
3199+ // The intent is to verify if the mapped data being passed is a
3200+ // pointer -> pointee that requires special handling in certain cases,
3201+ // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3202+ //
3203+ // There may be a better way to verify this, but unfortunately with
3204+ // opaque pointers we lose the ability to easily check if something is
3205+ // a pointer whilst maintaining access to the underlying type.
3206+ static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3207+ // If we have a varPtrPtr field assigned then the underlying type is a pointer
3208+ if (mapOp.getVarPtrPtr ())
3209+ return true ;
3210+
3211+ // If the map data is declare target with a link clause, then it's represented
3212+ // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3213+ // no relation to pointers.
3214+ if (isDeclareTargetLink (mapOp.getVarPtr ()))
3215+ return true ;
3216+
3217+ return false ;
3218+ }
3219+
31133220// This creates two insertions into the MapInfosTy data structure for the
31143221// "parent" of a set of members, (usually a container e.g.
31153222// class/structure/derived type) when subsequent members have also been
@@ -3150,7 +3257,6 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
31503257 // runtime information on the dynamically allocated data).
31513258 auto parentClause =
31523259 llvm::cast<omp::MapInfoOp>(mapData.MapClause [mapDataIndex]);
3153-
31543260 llvm::Value *lowAddr, *highAddr;
31553261 if (!parentClause.getPartialMap ()) {
31563262 lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
@@ -3197,37 +3303,77 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
31973303 // what we support as expected.
31983304 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types [mapDataIndex];
31993305 ompBuilder.setCorrectMemberOfFlag (mapFlag, memberOfFlag);
3200- combinedInfo.Types .emplace_back (mapFlag);
3201- combinedInfo.DevicePointers .emplace_back (
3202- llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3203- combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3204- mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3205- combinedInfo.BasePointers .emplace_back (mapData.BasePointers [mapDataIndex]);
3206- combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3207- combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3208- }
3209- return memberOfFlag;
3210- }
3211-
3212- // The intent is to verify if the mapped data being passed is a
3213- // pointer -> pointee that requires special handling in certain cases,
3214- // e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
3215- //
3216- // There may be a better way to verify this, but unfortunately with
3217- // opaque pointers we lose the ability to easily check if something is
3218- // a pointer whilst maintaining access to the underlying type.
3219- static bool checkIfPointerMap (omp::MapInfoOp mapOp) {
3220- // If we have a varPtrPtr field assigned then the underlying type is a pointer
3221- if (mapOp.getVarPtrPtr ())
3222- return true ;
32233306
3224- // If the map data is declare target with a link clause, then it's represented
3225- // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
3226- // no relation to pointers.
3227- if (isDeclareTargetLink (mapOp.getVarPtr ()))
3228- return true ;
3307+ if (targetDirective == TargetDirective::TargetUpdate) {
3308+ combinedInfo.Types .emplace_back (mapFlag);
3309+ combinedInfo.DevicePointers .emplace_back (
3310+ mapData.DevicePointers [mapDataIndex]);
3311+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3312+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3313+ combinedInfo.BasePointers .emplace_back (
3314+ mapData.BasePointers [mapDataIndex]);
3315+ combinedInfo.Pointers .emplace_back (mapData.Pointers [mapDataIndex]);
3316+ combinedInfo.Sizes .emplace_back (mapData.Sizes [mapDataIndex]);
3317+ } else {
3318+ llvm::SmallVector<size_t > overlapIdxs;
3319+ // Find all of the members that "overlap", i.e. occlude other members that
3320+ // were mapped alongside the parent, e.g. member [0], occludes
3321+ getOverlappedMembers (overlapIdxs, mapData, parentClause);
3322+ // We need to make sure the overlapped members are sorted in order of
3323+ // lowest address to highest address
3324+ sortMapIndices (overlapIdxs, parentClause);
3325+
3326+ lowAddr = builder.CreatePointerCast (mapData.Pointers [mapDataIndex],
3327+ builder.getPtrTy ());
3328+ highAddr = builder.CreatePointerCast (
3329+ builder.CreateConstGEP1_32 (mapData.BaseType [mapDataIndex],
3330+ mapData.Pointers [mapDataIndex], 1 ),
3331+ builder.getPtrTy ());
3332+
3333+ // TODO: We may want to skip arrays/array sections in this as Clang does
3334+ // so it appears to be an optimisation rather than a neccessity though,
3335+ // but this requires further investigation. However, we would have to make
3336+ // sure to not exclude maps with bounds that ARE pointers, as these are
3337+ // processed as seperate components, i.e. pointer + data.
3338+ for (auto v : overlapIdxs) {
3339+ auto mapDataOverlapIdx = getMapDataMemberIdx (
3340+ mapData,
3341+ cast<omp::MapInfoOp>(parentClause.getMembers ()[v].getDefiningOp ()));
3342+ combinedInfo.Types .emplace_back (mapFlag);
3343+ combinedInfo.DevicePointers .emplace_back (
3344+ mapData.DevicePointers [mapDataOverlapIdx]);
3345+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3346+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3347+ combinedInfo.BasePointers .emplace_back (
3348+ mapData.BasePointers [mapDataIndex]);
3349+ combinedInfo.Pointers .emplace_back (lowAddr);
3350+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3351+ builder.CreatePtrDiff (builder.getInt8Ty (),
3352+ mapData.OriginalValue [mapDataOverlapIdx],
3353+ lowAddr),
3354+ builder.getInt64Ty (), /* isSigned=*/ true ));
3355+ lowAddr = builder.CreateConstGEP1_32 (
3356+ checkIfPointerMap (llvm::cast<omp::MapInfoOp>(
3357+ mapData.MapClause [mapDataOverlapIdx]))
3358+ ? builder.getPtrTy ()
3359+ : mapData.BaseType [mapDataOverlapIdx],
3360+ mapData.BasePointers [mapDataOverlapIdx], 1 );
3361+ }
32293362
3230- return false ;
3363+ combinedInfo.Types .emplace_back (mapFlag);
3364+ combinedInfo.DevicePointers .emplace_back (
3365+ mapData.DevicePointers [mapDataIndex]);
3366+ combinedInfo.Names .emplace_back (LLVM::createMappingInformation (
3367+ mapData.MapClause [mapDataIndex]->getLoc (), ompBuilder));
3368+ combinedInfo.BasePointers .emplace_back (
3369+ mapData.BasePointers [mapDataIndex]);
3370+ combinedInfo.Pointers .emplace_back (lowAddr);
3371+ combinedInfo.Sizes .emplace_back (builder.CreateIntCast (
3372+ builder.CreatePtrDiff (builder.getInt8Ty (), highAddr, lowAddr),
3373+ builder.getInt64Ty (), true ));
3374+ }
3375+ }
3376+ return memberOfFlag;
32313377}
32323378
32333379// This function is intended to add explicit mappings of members
0 commit comments