@@ -77,6 +77,10 @@ class MapInfoFinalizationPass
   ///      |                     |
   std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
 
+  // List of deferrable descriptors to process at the end of
+  // the pass.
+  llvm::SmallVector<mlir::Operation *> deferrableDesc;
+
   /// Return true if the given path exists in a list of paths.
   static bool
   containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
@@ -183,6 +187,40 @@ class MapInfoFinalizationPass
     newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
   }
 
+  // Check if the declaration operation we have refers to a dummy
+  // function argument.
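+  // (Illustration: 'a' in "subroutine foo(a)" is a dummy argument; its
+  // hlfir.declare carries a dummy_scope operand, which is what we check.)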
+  bool isDummyArgument(mlir::Value mappedValue) {
+    if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
+            mappedValue.getDefiningOp()))
+      if (auto dummyScope = declareOp.getDummyScope())
+        return true;
+    return false;
+  }
+
+  // Relevant for OpenMP < 5.2, where attach semantics and rules don't exist.
+  // As descriptors were an unspoken implementation detail in these versions,
+  // there are certain cases where the user (and the compiler implementation)
+  // can create data mapping errors by leaving temporary descriptors stuck
+  // in device memory. The main example is a 'target enter data' map of an
+  // assumed-shape or assumed-size dummy argument without a corresponding
+  // exit: a local stack descriptor is generated, gets mapped, and is then
+  // left on the device. The user doesn't realize what they've done, as the
+  // OpenMP specification isn't explicit about descriptor handling in
+  // earlier versions and, as far as Fortran is concerned, descriptors are
+  // hidden from the user. To avoid this we can defer the descriptor mapping
+  // in these cases until target or target data regions, where we can be
+  // sure it has a clear, limited scope on the device.
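+  //
+  // For example (illustrative):
+  //
+  //   subroutine foo(a)
+  //     real :: a(:)
+  //     !$omp target enter data map(to: a)
+  //   end subroutine
+  //
+  // Here 'a' is an assumed-shape dummy argument; the map would otherwise
+  // capture the temporary stack descriptor of 'a', which goes out of scope
+  // after the call while remaining mapped on the device.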
+  bool canDeferDescriptorMapping(mlir::Value descriptor) {
+    if (fir::isAllocatableType(descriptor.getType()) ||
+        fir::isPointerType(descriptor.getType()))
+      return false;
+    if (isDummyArgument(descriptor) &&
+        (fir::isAssumedType(descriptor.getType()) ||
+         fir::isAssumedShape(descriptor.getType())))
+      return true;
+    return false;
+  }
+
   /// getMemberUserList gathers all users of a particular MapInfoOp that are
   /// other MapInfoOp's and places them into the mapMemberUsers list, which
   /// records the map that the current argument MapInfoOp "op" is part of
@@ -234,13 +272,16 @@ class MapInfoFinalizationPass
   /// fir::BoxOffsetOp we utilise to access the descriptor datas
   /// base address can be utilised.
   mlir::Value getDescriptorFromBoxMap(mlir::omp::MapInfoOp boxMap,
-                                      fir::FirOpBuilder &builder) {
+                                      fir::FirOpBuilder &builder,
+                                      bool &canDescBeDeferred) {
     mlir::Value descriptor = boxMap.getVarPtr();
     if (!fir::isTypeWithDescriptor(boxMap.getVarType()))
       if (auto addrOp = mlir::dyn_cast_if_present<fir::BoxAddrOp>(
               boxMap.getVarPtr().getDefiningOp()))
         descriptor = addrOp.getVal();
 
+    canDescBeDeferred = canDeferDescriptorMapping(descriptor);
+
     if (!mlir::isa<fir::BaseBoxType>(descriptor.getType()) &&
         !fir::factory::isOptionalArgument(descriptor.getDefiningOp()))
       return descriptor;
@@ -391,8 +432,7 @@ class MapInfoFinalizationPass
 
   /// Check if the mapOp is present in the HasDeviceAddr clause on
   /// the userOp. Only applies to TargetOp.
-  bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
-    assert(userOp && "Expecting non-null argument");
+  bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
     if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp)) {
       for (mlir::Value hda : targetOp.getHasDeviceAddrVars()) {
         if (hda.getDefiningOp() == mapOp)
@@ -402,6 +442,26 @@ class MapInfoFinalizationPass
     return false;
   }
 
+  bool isUseDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
+    if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
+      for (mlir::Value uda : targetDataOp.getUseDeviceAddrVars()) {
+        if (uda.getDefiningOp() == mapOp)
+          return true;
+      }
+    }
+    return false;
+  }
+
+  bool isUseDevicePtr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
+    if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
+      for (mlir::Value udp : targetDataOp.getUseDevicePtrVars()) {
+        if (udp.getDefiningOp() == mapOp)
+          return true;
+      }
+    }
+    return false;
+  }
+
   mlir::omp::MapInfoOp genBoxcharMemberMap(mlir::omp::MapInfoOp op,
                                            fir::FirOpBuilder &builder) {
     if (!op.getMembers().empty())
@@ -466,12 +526,14 @@ class MapInfoFinalizationPass
 
     // TODO: map the addendum segment of the descriptor, similarly to the
     // base address/data pointer member.
-    mlir::Value descriptor = getDescriptorFromBoxMap(op, builder);
+    bool descCanBeDeferred = false;
+    mlir::Value descriptor =
+        getDescriptorFromBoxMap(op, builder, descCanBeDeferred);
 
     mlir::ArrayAttr newMembersAttr;
     mlir::SmallVector<mlir::Value> newMembers;
     llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
-    bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);
+    bool isHasDeviceAddrFlag = isHasDeviceAddr(op, *target);
 
     if (!mapMemberUsers.empty() || !op.getMembers().empty())
       getMemberIndicesAsVectors(
@@ -553,6 +615,10 @@ class MapInfoFinalizationPass
         /*partial_map=*/builder.getBoolAttr(false));
     op.replaceAllUsesWith(newDescParentMapOp.getResult());
     op->erase();
+
+    if (descCanBeDeferred)
+      deferrableDesc.push_back(newDescParentMapOp);
+
     return newDescParentMapOp;
   }
 
@@ -701,6 +767,124 @@ class MapInfoFinalizationPass
     return nullptr;
   }
 
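+  // If a deferred descriptor only appears in a use_device_addr/use_device_ptr
+  // clause of a target data region, add an implicit, always-updated 'to' map
+  // of the descriptor itself so it is present on the device. A sketch of the
+  // kind of Fortran this is intended to handle (illustrative):
+  //
+  //   !$omp target enter data map(to: a)
+  //   !$omp target data use_device_addr(a)
+  //     ... use the device address of 'a' ...
+  //   !$omp end target data
+  //
+  // Here 'a' has no explicit map on the target data directive, so the
+  // descriptor map would otherwise be missing.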
+  void addImplicitDescriptorMapToTargetDataOp(mlir::omp::MapInfoOp op,
+                                              fir::FirOpBuilder &builder,
+                                              mlir::Operation &target) {
+    // Check whether the descriptor is already explicitly mapped on the
+    // target data directive, and not just present on a use_device_addr/ptr
+    // clause; if it is, we do not need to add an implicit map for the
+    // descriptor.
+    auto explicitMappingPresent = [](mlir::omp::MapInfoOp op,
+                                     mlir::omp::TargetDataOp tarData) {
+      // A matching varPtr/varPtrPtr on a top-level map is enough; the map
+      // type of a descriptor is always at least To, which is all we really
+      // care about here, as the aim is simply to make sure the descriptor
+      // is present on the device whenever we expect to access the
+      // underlying data.
+      if (tarData.getMapVars().empty())
+        return false;
+
+      for (mlir::Value mapVar : tarData.getMapVars()) {
+        auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(mapVar.getDefiningOp());
+        if (mapOp.getVarPtr() == op.getVarPtr() &&
+            mapOp.getVarPtrPtr() == op.getVarPtrPtr()) {
+          return true;
+        }
+      }
+
+      return false;
+    };
+
+    // If this is not a top-level descriptor with members (e.g. it is a
+    // member of a derived type), we do not want to perform this step.
+    if (!llvm::isa<mlir::omp::TargetDataOp>(target) || op.getMembers().empty())
+      return;
+
+    if (!isUseDeviceAddr(op, target) && !isUseDevicePtr(op, target))
+      return;
+
+    auto targetDataOp = llvm::cast<mlir::omp::TargetDataOp>(target);
+    if (explicitMappingPresent(op, targetDataOp))
+      return;
+
+    mlir::omp::MapInfoOp newDescParentMapOp =
+        builder.create<mlir::omp::MapInfoOp>(
+            op->getLoc(), op.getResult().getType(), op.getVarPtr(),
+            op.getVarTypeAttr(),
+            builder.getIntegerAttr(
+                builder.getIntegerType(64, false),
+                llvm::to_underlying(
+                    llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                    llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
+            op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
+            mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
+            /*bounds=*/mlir::SmallVector<mlir::Value>{},
+            /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+            /*partial_map=*/builder.getBoolAttr(false));
+
+    targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
+  }
+
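+  // For enter/exit data maps of a deferrable descriptor, drop the top-level
+  // (stack) descriptor map and promote the base-address member to the parent
+  // map, so only the underlying data is mapped at that point. A sketch of
+  // the situation this handles (illustrative):
+  //
+  //   subroutine foo(a)
+  //     real :: a(:)
+  //     !$omp target enter data map(to: a)   ! maps the data only; the
+  //                                          ! descriptor map is deferred
+  //   end subroutine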
+  void removeTopLevelDescriptor(mlir::omp::MapInfoOp op,
+                                fir::FirOpBuilder &builder,
+                                mlir::Operation *target) {
+    if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
+                  mlir::omp::DeclareMapperInfoOp>(target))
+      return;
+
+    // If this is not a top-level descriptor with members (e.g. it is a
+    // member of a derived type), we do not want to perform this step.
+    if (op.getMembers().empty())
+      return;
+
+    mlir::SmallVector<mlir::Value> members = op.getMembers();
+    mlir::omp::MapInfoOp baseAddr =
+        mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
+            members.front().getDefiningOp());
+    assert(baseAddr && "Expected member to be MapInfoOp");
+    members.erase(members.begin());
+
+    llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
+    getMemberIndicesAsVectors(op, memberIndices);
+
+    // We can skip the extra processing if there's only one member, as it
+    // would be the base address, which we're promoting to the parent.
+    mlir::ArrayAttr membersAttr;
+    if (memberIndices.size() > 1) {
+      memberIndices.erase(memberIndices.begin());
+      membersAttr = builder.create2DI64ArrayAttr(memberIndices);
+    }
+
+    // VarPtrPtr is currently tied to detecting whether something is a
+    // pointer in the later lowering, which in turn means OMP_MAP_PTR_AND_OBJ
+    // is applied. That reintroduces the problem this function tries to solve
+    // by emitting an 8-byte mapping tied to the descriptor address (even if
+    // we only emit a single map). We circumvent this by removing the
+    // varPtrPtr mapping; however, a side effect of this is that we lose the
+    // additional load the backend ties to it, which is required for
+    // correctness and for getting the correct address of the data to perform
+    // our mapping. So we emit the load at this stage instead.
+    // TODO/FIXME: Tidy up OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
+    // whether something is a pointer, to simplify the implementation a bit.
+    // This is an unfortunate complexity from push-back from upstream. We
+    // could also emit a load at this level for all base addresses, which
+    // would simplify the later lowering as well, but first we need to see
+    // how well this alteration works.
+    auto loadBaseAddr =
+        builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr());
+    mlir::omp::MapInfoOp newBaseAddrMapOp =
+        builder.create<mlir::omp::MapInfoOp>(
+            op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
+            baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
+            baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members,
+            membersAttr, baseAddr.getBounds(),
+            /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+            /*partial_map=*/builder.getBoolAttr(false));
+    op.replaceAllUsesWith(newBaseAddrMapOp.getResult());
+    op->erase();
+    baseAddr.erase();
+  }
+
   // This pass executes on omp::MapInfoOp's containing descriptor based types
   // (allocatables, pointers, assumed shape etc.) and expanding them into
   // multiple omp::MapInfoOp's for each pointer member contained within the
@@ -730,6 +914,7 @@ class MapInfoFinalizationPass
       // clear all local allocations we made for any boxes in any prior
       // iterations from previous function scopes.
       localBoxAllocas.clear();
+      deferrableDesc.clear();
 
       // First, walk `omp.map.info` ops to see if any of them have varPtrs
       // with an underlying type of fir.char<k, ?>, i.e a character
@@ -1010,6 +1195,36 @@ class MapInfoFinalizationPass
         }
       });
 
+      // Now that we've expanded all of our boxes into a descriptor and base
+      // address map where necessary, we check if the map owner is an
+      // enter/exit/target data directive, and if it is we drop the initial
+      // descriptor (top-level parent) and replace it with the
+      // base address/data.
+      //
+      // This circumvents issues with stack-allocated descriptors bound to
+      // the device colliding, which in Flang is rather easy for a user to
+      // cause by accident due to the pervasive generation of local
+      // intermediate descriptors whenever boxes are passed between scopes.
+      // In OpenMP 6+, mapping these would be a user error, as the tools
+      // required to circumvent these issues are provided by the spec
+      // (ref_ptr/ptee map types), but in prior specifications these tools
+      // are not available and it becomes an implementation issue for us to
+      // solve.
+      //
+      // We do this by dropping the top-level descriptor, which will be the
+      // stack descriptor, when we perform enter/exit maps, as we don't want
+      // it to be bound until necessary, which is when we use the descriptor
+      // type within a target region. At that point we map the relevant
+      // descriptor data, and the runtime should correctly associate the
+      // data with the descriptor and bind them together, allowing clean
+      // mapping and execution.
+      for (auto *op : deferrableDesc) {
+        auto mapOp = llvm::dyn_cast<mlir::omp::MapInfoOp>(op);
+        mlir::Operation *targetUser = getFirstTargetUser(mapOp);
+        assert(targetUser && "expected user of map operation was not found");
+        builder.setInsertionPoint(mapOp);
+        removeTopLevelDescriptor(mapOp, builder, targetUser);
+        addImplicitDescriptorMapToTargetDataOp(mapOp, builder, *targetUser);
+      }
+
       // Wait until after we have generated all of our maps to add them onto
       // the target's block arguments, simplifying the process as there would be
       // no need to avoid accidental duplicate additions.