@@ -544,6 +544,66 @@ class MapInfoFinalizationPass
544544 return nullptr ;
545545 }
546546
547+ void removeTopLevelDescriptor (mlir::omp::MapInfoOp op,
548+ fir::FirOpBuilder &builder,
549+ mlir::Operation *target) {
550+ if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
551+ mlir::omp::DeclareMapperInfoOp>(target))
552+ return ;
553+
554+ // if we're not a top level descriptor with members (e.g. member of a
555+ // derived type), we do not want to perform this step.
556+ if (op.getMembers ().empty ())
557+ return ;
558+
559+ mlir::SmallVector<mlir::Value> newMembers = op.getMembers ();
560+ mlir::omp::MapInfoOp baseAddr =
561+ mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
562+ newMembers.front ().getDefiningOp ());
563+ assert (baseAddr && " Expected member to be MapInfoOp" );
564+ newMembers.erase (newMembers.begin ());
565+
566+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
567+ getMemberIndicesAsVectors (op, memberIndices);
568+
569+ // Can skip the extra processing if there's only 1 member as it'd
570+ // be the base addresses, which we're promoting to the parent.
571+ mlir::ArrayAttr newMembersAttr;
572+ if (memberIndices.size () > 1 ) {
573+ memberIndices.erase (memberIndices.begin ());
574+ newMembersAttr = builder.create2DI64ArrayAttr (memberIndices);
575+ }
576+
577+ // VarPtrPtr is tied to detecting if something is a pointer in the later
578+ // lowering currently, this at the moment comes tied with
579+ // OMP_MAP_PTR_AND_OBJ being applied which breaks the problem this tries to
580+ // solve by emitting a 8-byte mapping tied to the descriptor address (even
581+ // if we only emit a single map). So we circumvent this by removing the
582+ // varPtrPtr mapping, however, a side affect of this is we lose the
583+ // additional load from the backend tied to this which is required for
584+ // correctness and getting the correct address of the data to perform our
585+ // mapping. So we do our load at this stage.
586+ // TODO/FIXME: Tidy up the OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
587+ // if something is a pointer to try and tidy up the implementation a bit.
588+ // This is an unfortunate complexity from push-back from upstream. We
589+ // could also emit a load at this level for all base addresses as well,
590+ // which in turn will simplify the later lowering a bit as well. But first
591+ // need to see how well this alteration works.
592+ auto loadBaseAddr =
593+ builder.loadIfRef (op->getLoc (), baseAddr.getVarPtrPtr ());
594+ mlir::omp::MapInfoOp newBaseAddrMapOp =
595+ builder.create <mlir::omp::MapInfoOp>(
596+ op->getLoc (), loadBaseAddr.getType (), loadBaseAddr,
597+ baseAddr.getVarTypeAttr (), baseAddr.getMapTypeAttr (),
598+ baseAddr.getMapCaptureTypeAttr (), mlir::Value{}, newMembers,
599+ newMembersAttr, baseAddr.getBounds (),
600+ /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
601+ /* partial_map=*/ builder.getBoolAttr (false ));
602+ op.replaceAllUsesWith (newBaseAddrMapOp.getResult ());
603+ op->erase ();
604+ baseAddr.erase ();
605+ }
606+
547607 // This pass executes on omp::MapInfoOp's containing descriptor based types
548608 // (allocatables, pointers, assumed shape etc.) and expanding them into
549609 // multiple omp::MapInfoOp's for each pointer member contained within the
@@ -791,6 +851,38 @@ class MapInfoFinalizationPass
791851 }
792852 });
793853
854+ // Now that we've expanded all of our boxes into a descriptor and base
855+ // address map where necessary, we check if the map owner is an
856+ // enter/exit/target data directive, and if they are we drop the initial
857+ // descriptor (top-level parent) and replace it with the
858+ // base_address/data.
859+ //
860+ // This circumvents issues with stack allocated descriptors bound to
861+ // device colliding which in Flang is rather trivial for a user to do by
862+ // accident due to the rather pervasive local intermediate descriptor
863+ // generation that occurs whenever you pass boxes around different scopes.
864+ // In OpenMP 6+ mapping these would be a user error as the tools required
865+ // to circumvent these issues are provided by the spec (ref_ptr/ptee map
866+ // types), but in prior specifications these tools are not available and
867+ // it becomes an implementation issue for us to solve.
868+ //
869+ // We do this by dropping the top-level descriptor which will be the stack
870+ // descriptor when we perform enter/exit maps, as we don't want these to
871+ // be bound until necessary which is when we utilise the descriptor type
872+ // within a target region. At which point we map the relevant descriptor
873+ // data and the runtime should correctly associate the data with the
874+ // descriptor and bind together and allow clean mapping and execution.
875+ func->walk ([&](mlir::omp::MapInfoOp op) {
876+ if (fir::isTypeWithDescriptor (op.getVarType ()) ||
877+ mlir::isa_and_present<fir::BoxAddrOp>(
878+ op.getVarPtr ().getDefiningOp ())) {
879+ mlir::Operation *targetUser = getFirstTargetUser (op);
880+ assert (targetUser && " expected user of map operation was not found" );
881+ builder.setInsertionPoint (op);
882+ removeTopLevelDescriptor (op, builder, targetUser);
883+ }
884+ });
885+
794886 // Wait until after we have generated all of our maps to add them onto
795887 // the target's block arguments, simplifying the process as there would be
796888 // no need to avoid accidental duplicate additions.
0 commit comments