@@ -75,6 +75,44 @@ class MapInfoFinalizationPass
7575 // / | |
7676 std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
7777
78+ // Map operations for descriptors whose mapping can be deferred;
79+ // these are processed at the end of the pass.
80+ llvm::SmallVector<mlir::Operation *> deferrableDesc;
81+
82+ // Check if the declaration operation we have refers to a dummy
83+ // function argument.
84+ bool isDummyArgument (mlir::Value mappedValue) {
85+ if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
86+ mappedValue.getDefiningOp ()))
87+ if (auto dummyScope = declareOp.getDummyScope ())
88+ return true ;
89+ return false ;
90+ }
91+
92+ // Relevant for OpenMP < 5.2, where attach semantics and rules don't exist.
93+ // As descriptors were an unspoken implementation detail in these versions
94+ // there's certain cases where the user (and the compiler implementation)
95+ // can create data mapping errors by having temporary descriptors stuck
96+ // in memory. The main example is calling an 'target enter data map'
97+ // without a corresponding exit on an assumed shape or size dummy
98+ // argument, a local stack descriptor is generated, gets mapped and
99+ // is then left on device. A user doesn't realize what they've done as
100+ // the OpenMP specification isn't explicit on descriptor handling in
101+ // earlier versions and as far as Fortran is concerned this si something
102+ // hidden from a user. To avoid this we can defer the descriptor mapping
103+ // in these cases until target or target data regions, when we can be
104+ // sure they have a clear limited scope on device.
105+ bool canDeferDescriptorMapping (mlir::Value descriptor) {
106+ if (fir::isAllocatableType (descriptor.getType ()) ||
107+ fir::isPointerType (descriptor.getType ()))
108+ return false ;
109+ if (isDummyArgument (descriptor) &&
110+ (fir::isAssumedType (descriptor.getType ()) ||
111+ fir::isAssumedShape (descriptor.getType ())))
112+ return true ;
113+ return false ;
114+ }
115+
78116 // / getMemberUserList gathers all users of a particular MapInfoOp that are
79117 // / other MapInfoOp's and places them into the mapMemberUsers list, which
80118 // / records the map that the current argument MapInfoOp "op" is part of
@@ -126,13 +164,16 @@ class MapInfoFinalizationPass
126164 // / fir::BoxOffsetOp we utilise to access the descriptor datas
127165 // / base address can be utilised.
128166 mlir::Value getDescriptorFromBoxMap (mlir::omp::MapInfoOp boxMap,
129- fir::FirOpBuilder &builder) {
167+ fir::FirOpBuilder &builder,
168+ bool &canDescBeDeferred) {
130169 mlir::Value descriptor = boxMap.getVarPtr ();
131170 if (!fir::isTypeWithDescriptor (boxMap.getVarType ()))
132171 if (auto addrOp = mlir::dyn_cast_if_present<fir::BoxAddrOp>(
133172 boxMap.getVarPtr ().getDefiningOp ()))
134173 descriptor = addrOp.getVal ();
135174
175+ canDescBeDeferred = canDeferDescriptorMapping (descriptor);
176+
136177 if (!mlir::isa<fir::BaseBoxType>(descriptor.getType ()) &&
137178 !fir::factory::isOptionalArgument (descriptor.getDefiningOp ()))
138179 return descriptor;
@@ -283,8 +324,7 @@ class MapInfoFinalizationPass
283324
284325 // / Check if the mapOp is present in the HasDeviceAddr clause on
285326 // / the userOp. Only applies to TargetOp.
286- bool isHasDeviceAddr (mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
287- assert (userOp && " Expecting non-null argument" );
327+ bool isHasDeviceAddr (mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
288328 if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp)) {
289329 for (mlir::Value hda : targetOp.getHasDeviceAddrVars ()) {
290330 if (hda.getDefiningOp () == mapOp)
@@ -294,6 +334,26 @@ class MapInfoFinalizationPass
294334 return false ;
295335 }
296336
337+ bool isUseDeviceAddr (mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
338+ if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
339+ for (mlir::Value uda : targetDataOp.getUseDeviceAddrVars ()) {
340+ if (uda.getDefiningOp () == mapOp)
341+ return true ;
342+ }
343+ }
344+ return false ;
345+ }
346+
347+ bool isUseDevicePtr (mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
348+ if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
349+ for (mlir::Value udp : targetDataOp.getUseDevicePtrVars ()) {
350+ if (udp.getDefiningOp () == mapOp)
351+ return true ;
352+ }
353+ }
354+ return false ;
355+ }
356+
297357 mlir::omp::MapInfoOp genBoxcharMemberMap (mlir::omp::MapInfoOp op,
298358 fir::FirOpBuilder &builder) {
299359 if (!op.getMembers ().empty ())
@@ -358,12 +418,14 @@ class MapInfoFinalizationPass
358418
359419 // TODO: map the addendum segment of the descriptor, similarly to the
360420 // base address/data pointer member.
361- mlir::Value descriptor = getDescriptorFromBoxMap (op, builder);
421+ bool descCanBeDeferred = false ;
422+ mlir::Value descriptor =
423+ getDescriptorFromBoxMap (op, builder, descCanBeDeferred);
362424
363425 mlir::ArrayAttr newMembersAttr;
364426 mlir::SmallVector<mlir::Value> newMembers;
365427 llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
366- bool IsHasDeviceAddr = isHasDeviceAddr (op, target);
428+ bool IsHasDeviceAddr = isHasDeviceAddr (op, * target);
367429
368430 if (!mapMemberUsers.empty () || !op.getMembers ().empty ())
369431 getMemberIndicesAsVectors (
@@ -445,6 +507,10 @@ class MapInfoFinalizationPass
445507 /* partial_map=*/ builder.getBoolAttr (false ));
446508 op.replaceAllUsesWith (newDescParentMapOp.getResult ());
447509 op->erase ();
510+
511+ if (descCanBeDeferred)
512+ deferrableDesc.push_back (newDescParentMapOp);
513+
448514 return newDescParentMapOp;
449515 }
450516
@@ -593,6 +659,124 @@ class MapInfoFinalizationPass
593659 return nullptr ;
594660 }
595661
662+ void addImplicitDescriptorMapToTargetDataOp (mlir::omp::MapInfoOp op,
663+ fir::FirOpBuilder &builder,
664+ mlir::Operation &target) {
665+ // Checks if the map is present as an explicit map already on the target
666+ // data directive, and not just present on a use_device_addr/ptr, as if
667+ // that's the case, we should not need to add an implicit map for the
668+ // descriptor.
669+ auto explicitMappingPresent = [](mlir::omp::MapInfoOp op,
670+ mlir::omp::TargetDataOp tarData) {
671+ // Verify top-level descriptor mapping is at least equal with same
672+ // varPtr, the map type should always be To for a descriptor, which is
673+ // all we really care about for this mapping as we aim to make sure the
674+ // descriptor is always present on device if we're expecting to access
675+ // the underlying data.
676+ if (tarData.getMapVars ().empty ())
677+ return false ;
678+
679+ for (mlir::Value mapVar : tarData.getMapVars ()) {
680+ auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(mapVar.getDefiningOp ());
681+ if (mapOp.getVarPtr () == op.getVarPtr () &&
682+ mapOp.getVarPtrPtr () == op.getVarPtrPtr ()) {
683+ return true ;
684+ }
685+ }
686+
687+ return false ;
688+ };
689+
690+ // if we're not a top level descriptor with members (e.g. member of a
691+ // derived type), we do not want to perform this step.
692+ if (!llvm::isa<mlir::omp::TargetDataOp>(target) || op.getMembers ().empty ())
693+ return ;
694+
695+ if (!isUseDeviceAddr (op, target) && !isUseDevicePtr (op, target))
696+ return ;
697+
698+ auto targetDataOp = llvm::cast<mlir::omp::TargetDataOp>(target);
699+ if (explicitMappingPresent (op, targetDataOp))
700+ return ;
701+
702+ mlir::omp::MapInfoOp newDescParentMapOp =
703+ builder.create <mlir::omp::MapInfoOp>(
704+ op->getLoc (), op.getResult ().getType (), op.getVarPtr (),
705+ op.getVarTypeAttr (),
706+ builder.getIntegerAttr (
707+ builder.getIntegerType (64 , false ),
708+ llvm::to_underlying (
709+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
710+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
711+ op.getMapCaptureTypeAttr (), /* varPtrPtr=*/ mlir::Value{},
712+ mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
713+ /* bounds=*/ mlir::SmallVector<mlir::Value>{},
714+ /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
715+ /* partial_map=*/ builder.getBoolAttr (false ));
716+
717+ targetDataOp.getMapVarsMutable ().append ({newDescParentMapOp});
718+ }
719+
720+ void removeTopLevelDescriptor (mlir::omp::MapInfoOp op,
721+ fir::FirOpBuilder &builder,
722+ mlir::Operation *target) {
723+ if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
724+ mlir::omp::DeclareMapperInfoOp>(target))
725+ return ;
726+
727+ // if we're not a top level descriptor with members (e.g. member of a
728+ // derived type), we do not want to perform this step.
729+ if (op.getMembers ().empty ())
730+ return ;
731+
732+ mlir::SmallVector<mlir::Value> members = op.getMembers ();
733+ mlir::omp::MapInfoOp baseAddr =
734+ mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
735+ members.front ().getDefiningOp ());
736+ assert (baseAddr && " Expected member to be MapInfoOp" );
737+ members.erase (members.begin ());
738+
739+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
740+ getMemberIndicesAsVectors (op, memberIndices);
741+
742+ // Can skip the extra processing if there's only 1 member as it'd
743+ // be the base addresses, which we're promoting to the parent.
744+ mlir::ArrayAttr membersAttr;
745+ if (memberIndices.size () > 1 ) {
746+ memberIndices.erase (memberIndices.begin ());
747+ membersAttr = builder.create2DI64ArrayAttr (memberIndices);
748+ }
749+
750+ // VarPtrPtr is tied to detecting if something is a pointer in the later
751+ // lowering currently, this at the moment comes tied with
752+ // OMP_MAP_PTR_AND_OBJ being applied which breaks the problem this tries to
753+ // solve by emitting a 8-byte mapping tied to the descriptor address (even
754+ // if we only emit a single map). So we circumvent this by removing the
755+ // varPtrPtr mapping, however, a side affect of this is we lose the
756+ // additional load from the backend tied to this which is required for
757+ // correctness and getting the correct address of the data to perform our
758+ // mapping. So we do our load at this stage.
759+ // TODO/FIXME: Tidy up the OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
760+ // if something is a pointer to try and tidy up the implementation a bit.
761+ // This is an unfortunate complexity from push-back from upstream. We
762+ // could also emit a load at this level for all base addresses as well,
763+ // which in turn will simplify the later lowering a bit as well. But first
764+ // need to see how well this alteration works.
765+ auto loadBaseAddr =
766+ builder.loadIfRef (op->getLoc (), baseAddr.getVarPtrPtr ());
767+ mlir::omp::MapInfoOp newBaseAddrMapOp =
768+ builder.create <mlir::omp::MapInfoOp>(
769+ op->getLoc (), loadBaseAddr.getType (), loadBaseAddr,
770+ baseAddr.getVarTypeAttr (), baseAddr.getMapTypeAttr (),
771+ baseAddr.getMapCaptureTypeAttr (), mlir::Value{}, members,
772+ membersAttr, baseAddr.getBounds (),
773+ /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
774+ /* partial_map=*/ builder.getBoolAttr (false ));
775+ op.replaceAllUsesWith (newBaseAddrMapOp.getResult ());
776+ op->erase ();
777+ baseAddr.erase ();
778+ }
779+
596780 // This pass executes on omp::MapInfoOp's containing descriptor based types
597781 // (allocatables, pointers, assumed shape etc.) and expanding them into
598782 // multiple omp::MapInfoOp's for each pointer member contained within the
@@ -622,6 +806,7 @@ class MapInfoFinalizationPass
622806 // clear all local allocations we made for any boxes in any prior
623807 // iterations from previous function scopes.
624808 localBoxAllocas.clear ();
809+ deferrableDesc.clear ();
625810
626811 // First, walk `omp.map.info` ops to see if any of them have varPtrs
627812 // with an underlying type of fir.char<k, ?>, i.e a character
@@ -864,6 +1049,36 @@ class MapInfoFinalizationPass
8641049 }
8651050 });
8661051
1052+ // Now that we've expanded all of our boxes into a descriptor and base
1053+ // address map where necessary, we check if the map owner is an
1054+ // enter/exit/target data directive, and if they are we drop the initial
1055+ // descriptor (top-level parent) and replace it with the
1056+ // base_address/data.
1057+ //
1058+ // This circumvents issues with stack allocated descriptors bound to
1059+ // device colliding which in Flang is rather trivial for a user to do by
1060+ // accident due to the rather pervasive local intermediate descriptor
1061+ // generation that occurs whenever you pass boxes around different scopes.
1062+ // In OpenMP 6+ mapping these would be a user error as the tools required
1063+ // to circumvent these issues are provided by the spec (ref_ptr/ptee map
1064+ // types), but in prior specifications these tools are not available and
1065+ // it becomes an implementation issue for us to solve.
1066+ //
1067+ // We do this by dropping the top-level descriptor which will be the stack
1068+ // descriptor when we perform enter/exit maps, as we don't want these to
1069+ // be bound until necessary which is when we utilise the descriptor type
1070+ // within a target region. At which point we map the relevant descriptor
1071+ // data and the runtime should correctly associate the data with the
1072+ // descriptor and bind together and allow clean mapping and execution.
1073+ for (auto *op : deferrableDesc) {
1074+ auto mapOp = llvm::dyn_cast<mlir::omp::MapInfoOp>(op);
1075+ mlir::Operation *targetUser = getFirstTargetUser (mapOp);
1076+ assert (targetUser && " expected user of map operation was not found" );
1077+ builder.setInsertionPoint (mapOp);
1078+ removeTopLevelDescriptor (mapOp, builder, targetUser);
1079+ addImplicitDescriptorMapToTargetDataOp (mapOp, builder, *targetUser);
1080+ }
1081+
8671082 // Wait until after we have generated all of our maps to add them onto
8681083 // the target's block arguments, simplifying the process as there would be
8691084 // no need to avoid accidental duplicate additions.
0 commit comments