@@ -544,6 +544,66 @@ class MapInfoFinalizationPass
544
544
return nullptr ;
545
545
}
546
546
547
+ void removeTopLevelDescriptor (mlir::omp::MapInfoOp op,
548
+ fir::FirOpBuilder &builder,
549
+ mlir::Operation *target) {
550
+ if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
551
+ mlir::omp::DeclareMapperInfoOp>(target))
552
+ return ;
553
+
554
+ // if we're not a top level descriptor with members (e.g. member of a
555
+ // derived type), we do not want to perform this step.
556
+ if (op.getMembers ().empty ())
557
+ return ;
558
+
559
+ mlir::SmallVector<mlir::Value> newMembers = op.getMembers ();
560
+ mlir::omp::MapInfoOp baseAddr =
561
+ mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
562
+ newMembers.front ().getDefiningOp ());
563
+ assert (baseAddr && " Expected member to be MapInfoOp" );
564
+ newMembers.erase (newMembers.begin ());
565
+
566
+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
567
+ getMemberIndicesAsVectors (op, memberIndices);
568
+
569
+ // Can skip the extra processing if there's only 1 member as it'd
570
+ // be the base addresses, which we're promoting to the parent.
571
+ mlir::ArrayAttr newMembersAttr;
572
+ if (memberIndices.size () > 1 ) {
573
+ memberIndices.erase (memberIndices.begin ());
574
+ newMembersAttr = builder.create2DI64ArrayAttr (memberIndices);
575
+ }
576
+
577
+ // VarPtrPtr is tied to detecting if something is a pointer in the later
578
+ // lowering currently, this at the moment comes tied with
579
+ // OMP_MAP_PTR_AND_OBJ being applied which breaks the problem this tries to
580
+ // solve by emitting a 8-byte mapping tied to the descriptor address (even
581
+ // if we only emit a single map). So we circumvent this by removing the
582
+ // varPtrPtr mapping, however, a side affect of this is we lose the
583
+ // additional load from the backend tied to this which is required for
584
+ // correctness and getting the correct address of the data to perform our
585
+ // mapping. So we do our load at this stage.
586
+ // TODO/FIXME: Tidy up the OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
587
+ // if something is a pointer to try and tidy up the implementation a bit.
588
+ // This is an unfortunate complexity from push-back from upstream. We
589
+ // could also emit a load at this level for all base addresses as well,
590
+ // which in turn will simplify the later lowering a bit as well. But first
591
+ // need to see how well this alteration works.
592
+ auto loadBaseAddr =
593
+ builder.loadIfRef (op->getLoc (), baseAddr.getVarPtrPtr ());
594
+ mlir::omp::MapInfoOp newBaseAddrMapOp =
595
+ builder.create <mlir::omp::MapInfoOp>(
596
+ op->getLoc (), loadBaseAddr.getType (), loadBaseAddr,
597
+ baseAddr.getVarTypeAttr (), baseAddr.getMapTypeAttr (),
598
+ baseAddr.getMapCaptureTypeAttr (), mlir::Value{}, newMembers,
599
+ newMembersAttr, baseAddr.getBounds (),
600
+ /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
601
+ /* partial_map=*/ builder.getBoolAttr (false ));
602
+ op.replaceAllUsesWith (newBaseAddrMapOp.getResult ());
603
+ op->erase ();
604
+ baseAddr.erase ();
605
+ }
606
+
547
607
// This pass executes on omp::MapInfoOp's containing descriptor based types
548
608
// (allocatables, pointers, assumed shape etc.) and expanding them into
549
609
// multiple omp::MapInfoOp's for each pointer member contained within the
@@ -791,6 +851,38 @@ class MapInfoFinalizationPass
791
851
}
792
852
});
793
853
854
+ // Now that we've expanded all of our boxes into a descriptor and base
855
+ // address map where necessary, we check if the map owner is an
856
+ // enter/exit/target data directive, and if they are we drop the initial
857
+ // descriptor (top-level parent) and replace it with the
858
+ // base_address/data.
859
+ //
860
+ // This circumvents issues with stack allocated descriptors bound to
861
+ // device colliding which in Flang is rather trivial for a user to do by
862
+ // accident due to the rather pervasive local intermediate descriptor
863
+ // generation that occurs whenever you pass boxes around different scopes.
864
+ // In OpenMP 6+ mapping these would be a user error as the tools required
865
+ // to circumvent these issues are provided by the spec (ref_ptr/ptee map
866
+ // types), but in prior specifications these tools are not available and
867
+ // it becomes an implementation issue for us to solve.
868
+ //
869
+ // We do this by dropping the top-level descriptor which will be the stack
870
+ // descriptor when we perform enter/exit maps, as we don't want these to
871
+ // be bound until necessary which is when we utilise the descriptor type
872
+ // within a target region. At which point we map the relevant descriptor
873
+ // data and the runtime should correctly associate the data with the
874
+ // descriptor and bind together and allow clean mapping and execution.
875
+ func->walk ([&](mlir::omp::MapInfoOp op) {
876
+ if (fir::isTypeWithDescriptor (op.getVarType ()) ||
877
+ mlir::isa_and_present<fir::BoxAddrOp>(
878
+ op.getVarPtr ().getDefiningOp ())) {
879
+ mlir::Operation *targetUser = getFirstTargetUser (op);
880
+ assert (targetUser && " expected user of map operation was not found" );
881
+ builder.setInsertionPoint (op);
882
+ removeTopLevelDescriptor (op, builder, targetUser);
883
+ }
884
+ });
885
+
794
886
// Wait until after we have generated all of our maps to add them onto
795
887
// the target's block arguments, simplifying the process as there would be
796
888
// no need to avoid accidental duplicate additions.
0 commit comments