@@ -77,6 +77,10 @@ class MapInfoFinalizationPass
77
77
// / | |
78
78
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
79
79
80
+ // List of deferrable descriptors to process at the end of
81
+ // the pass.
82
+ llvm::SmallVector<mlir::Operation *> deferrableDesc;
83
+
80
84
// / Return true if the given path exists in a list of paths.
81
85
static bool
82
86
containsPath (const llvm::SmallVectorImpl<llvm::SmallVector<int64_t >> &paths,
@@ -183,6 +187,40 @@ class MapInfoFinalizationPass
183
187
newMemberIndexPaths.emplace_back (indexPath.begin (), indexPath.end ());
184
188
}
185
189
190
+ // Check if the declaration operation we have refers to a dummy
191
+ // function argument.
192
+ bool isDummyArgument (mlir::Value mappedValue) {
193
+ if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
194
+ mappedValue.getDefiningOp ()))
195
+ if (auto dummyScope = declareOp.getDummyScope ())
196
+ return true ;
197
+ return false ;
198
+ }
199
+
200
+ // Relevant for OpenMP < 5.2, where attach semantics and rules don't exist.
201
+ // As descriptors were an unspoken implementation detail in these versions
202
+ // there's certain cases where the user (and the compiler implementation)
203
+ // can create data mapping errors by having temporary descriptors stuck
204
+ // in memory. The main example is calling an 'target enter data map'
205
+ // without a corresponding exit on an assumed shape or size dummy
206
+ // argument, a local stack descriptor is generated, gets mapped and
207
+ // is then left on device. A user doesn't realize what they've done as
208
+ // the OpenMP specification isn't explicit on descriptor handling in
209
+ // earlier versions and as far as Fortran is concerned this si something
210
+ // hidden from a user. To avoid this we can defer the descriptor mapping
211
+ // in these cases until target or target data regions, when we can be
212
+ // sure they have a clear limited scope on device.
213
+ bool canDeferDescriptorMapping (mlir::Value descriptor) {
214
+ if (fir::isAllocatableType (descriptor.getType ()) ||
215
+ fir::isPointerType (descriptor.getType ()))
216
+ return false ;
217
+ if (isDummyArgument (descriptor) &&
218
+ (fir::isAssumedType (descriptor.getType ()) ||
219
+ fir::isAssumedShape (descriptor.getType ())))
220
+ return true ;
221
+ return false ;
222
+ }
223
+
186
224
// / getMemberUserList gathers all users of a particular MapInfoOp that are
187
225
// / other MapInfoOp's and places them into the mapMemberUsers list, which
188
226
// / records the map that the current argument MapInfoOp "op" is part of
@@ -234,13 +272,16 @@ class MapInfoFinalizationPass
234
272
// / fir::BoxOffsetOp we utilise to access the descriptor data's
235
273
// / base address can be utilised.
236
274
mlir::Value getDescriptorFromBoxMap (mlir::omp::MapInfoOp boxMap,
237
- fir::FirOpBuilder &builder) {
275
+ fir::FirOpBuilder &builder,
276
+ bool &canDescBeDeferred) {
238
277
mlir::Value descriptor = boxMap.getVarPtr ();
239
278
if (!fir::isTypeWithDescriptor (boxMap.getVarType ()))
240
279
if (auto addrOp = mlir::dyn_cast_if_present<fir::BoxAddrOp>(
241
280
boxMap.getVarPtr ().getDefiningOp ()))
242
281
descriptor = addrOp.getVal ();
243
282
283
+ canDescBeDeferred = canDeferDescriptorMapping (descriptor);
284
+
244
285
if (!mlir::isa<fir::BaseBoxType>(descriptor.getType ()) &&
245
286
!fir::factory::isOptionalArgument (descriptor.getDefiningOp ()))
246
287
return descriptor;
@@ -391,8 +432,7 @@ class MapInfoFinalizationPass
391
432
392
433
// / Check if the mapOp is present in the HasDeviceAddr clause on
393
434
// / the userOp. Only applies to TargetOp.
394
- bool isHasDeviceAddr (mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
395
- assert (userOp && " Expecting non-null argument" );
435
+ bool isHasDeviceAddr (mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
396
436
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp)) {
397
437
for (mlir::Value hda : targetOp.getHasDeviceAddrVars ()) {
398
438
if (hda.getDefiningOp () == mapOp)
@@ -402,6 +442,26 @@ class MapInfoFinalizationPass
402
442
return false ;
403
443
}
404
444
445
+ bool isUseDeviceAddr (mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
446
+ if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
447
+ for (mlir::Value uda : targetDataOp.getUseDeviceAddrVars ()) {
448
+ if (uda.getDefiningOp () == mapOp)
449
+ return true ;
450
+ }
451
+ }
452
+ return false ;
453
+ }
454
+
455
+ bool isUseDevicePtr (mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
456
+ if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
457
+ for (mlir::Value udp : targetDataOp.getUseDevicePtrVars ()) {
458
+ if (udp.getDefiningOp () == mapOp)
459
+ return true ;
460
+ }
461
+ }
462
+ return false ;
463
+ }
464
+
405
465
mlir::omp::MapInfoOp genBoxcharMemberMap (mlir::omp::MapInfoOp op,
406
466
fir::FirOpBuilder &builder) {
407
467
if (!op.getMembers ().empty ())
@@ -466,12 +526,14 @@ class MapInfoFinalizationPass
466
526
467
527
// TODO: map the addendum segment of the descriptor, similarly to the
468
528
// base address/data pointer member.
469
- mlir::Value descriptor = getDescriptorFromBoxMap (op, builder);
529
+ bool descCanBeDeferred = false ;
530
+ mlir::Value descriptor =
531
+ getDescriptorFromBoxMap (op, builder, descCanBeDeferred);
470
532
471
533
mlir::ArrayAttr newMembersAttr;
472
534
mlir::SmallVector<mlir::Value> newMembers;
473
535
llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
474
- bool isHasDeviceAddrFlag = isHasDeviceAddr (op, target);
536
+ bool isHasDeviceAddrFlag = isHasDeviceAddr (op, * target);
475
537
476
538
if (!mapMemberUsers.empty () || !op.getMembers ().empty ())
477
539
getMemberIndicesAsVectors (
@@ -553,6 +615,10 @@ class MapInfoFinalizationPass
553
615
/* partial_map=*/ builder.getBoolAttr (false ));
554
616
op.replaceAllUsesWith (newDescParentMapOp.getResult ());
555
617
op->erase ();
618
+
619
+ if (descCanBeDeferred)
620
+ deferrableDesc.push_back (newDescParentMapOp);
621
+
556
622
return newDescParentMapOp;
557
623
}
558
624
@@ -701,6 +767,124 @@ class MapInfoFinalizationPass
701
767
return nullptr ;
702
768
}
703
769
770
+ void addImplicitDescriptorMapToTargetDataOp (mlir::omp::MapInfoOp op,
771
+ fir::FirOpBuilder &builder,
772
+ mlir::Operation &target) {
773
+ // Checks if the map is present as an explicit map already on the target
774
+ // data directive, and not just present on a use_device_addr/ptr, as if
775
+ // that's the case, we should not need to add an implicit map for the
776
+ // descriptor.
777
+ auto explicitMappingPresent = [](mlir::omp::MapInfoOp op,
778
+ mlir::omp::TargetDataOp tarData) {
779
+ // Verify top-level descriptor mapping is at least equal with same
780
+ // varPtr, the map type should always be To for a descriptor, which is
781
+ // all we really care about for this mapping as we aim to make sure the
782
+ // descriptor is always present on device if we're expecting to access
783
+ // the underlying data.
784
+ if (tarData.getMapVars ().empty ())
785
+ return false ;
786
+
787
+ for (mlir::Value mapVar : tarData.getMapVars ()) {
788
+ auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(mapVar.getDefiningOp ());
789
+ if (mapOp.getVarPtr () == op.getVarPtr () &&
790
+ mapOp.getVarPtrPtr () == op.getVarPtrPtr ()) {
791
+ return true ;
792
+ }
793
+ }
794
+
795
+ return false ;
796
+ };
797
+
798
+ // if we're not a top level descriptor with members (e.g. member of a
799
+ // derived type), we do not want to perform this step.
800
+ if (!llvm::isa<mlir::omp::TargetDataOp>(target) || op.getMembers ().empty ())
801
+ return ;
802
+
803
+ if (!isUseDeviceAddr (op, target) && !isUseDevicePtr (op, target))
804
+ return ;
805
+
806
+ auto targetDataOp = llvm::cast<mlir::omp::TargetDataOp>(target);
807
+ if (explicitMappingPresent (op, targetDataOp))
808
+ return ;
809
+
810
+ mlir::omp::MapInfoOp newDescParentMapOp =
811
+ builder.create <mlir::omp::MapInfoOp>(
812
+ op->getLoc (), op.getResult ().getType (), op.getVarPtr (),
813
+ op.getVarTypeAttr (),
814
+ builder.getIntegerAttr (
815
+ builder.getIntegerType (64 , false ),
816
+ llvm::to_underlying (
817
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
818
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
819
+ op.getMapCaptureTypeAttr (), /* varPtrPtr=*/ mlir::Value{},
820
+ mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
821
+ /* bounds=*/ mlir::SmallVector<mlir::Value>{},
822
+ /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
823
+ /* partial_map=*/ builder.getBoolAttr (false ));
824
+
825
+ targetDataOp.getMapVarsMutable ().append ({newDescParentMapOp});
826
+ }
827
+
828
+ void removeTopLevelDescriptor (mlir::omp::MapInfoOp op,
829
+ fir::FirOpBuilder &builder,
830
+ mlir::Operation *target) {
831
+ if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
832
+ mlir::omp::DeclareMapperInfoOp>(target))
833
+ return ;
834
+
835
+ // if we're not a top level descriptor with members (e.g. member of a
836
+ // derived type), we do not want to perform this step.
837
+ if (op.getMembers ().empty ())
838
+ return ;
839
+
840
+ mlir::SmallVector<mlir::Value> members = op.getMembers ();
841
+ mlir::omp::MapInfoOp baseAddr =
842
+ mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
843
+ members.front ().getDefiningOp ());
844
+ assert (baseAddr && " Expected member to be MapInfoOp" );
845
+ members.erase (members.begin ());
846
+
847
+ llvm::SmallVector<llvm::SmallVector<int64_t >> memberIndices;
848
+ getMemberIndicesAsVectors (op, memberIndices);
849
+
850
+ // Can skip the extra processing if there's only 1 member as it'd
851
+ // be the base addresses, which we're promoting to the parent.
852
+ mlir::ArrayAttr membersAttr;
853
+ if (memberIndices.size () > 1 ) {
854
+ memberIndices.erase (memberIndices.begin ());
855
+ membersAttr = builder.create2DI64ArrayAttr (memberIndices);
856
+ }
857
+
858
+ // VarPtrPtr is tied to detecting if something is a pointer in the later
859
+ // lowering currently, this at the moment comes tied with
860
+ // OMP_MAP_PTR_AND_OBJ being applied which breaks the problem this tries to
861
+ // solve by emitting a 8-byte mapping tied to the descriptor address (even
862
+ // if we only emit a single map). So we circumvent this by removing the
863
+ // varPtrPtr mapping, however, a side affect of this is we lose the
864
+ // additional load from the backend tied to this which is required for
865
+ // correctness and getting the correct address of the data to perform our
866
+ // mapping. So we do our load at this stage.
867
+ // TODO/FIXME: Tidy up the OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
868
+ // if something is a pointer to try and tidy up the implementation a bit.
869
+ // This is an unfortunate complexity from push-back from upstream. We
870
+ // could also emit a load at this level for all base addresses as well,
871
+ // which in turn will simplify the later lowering a bit as well. But first
872
+ // need to see how well this alteration works.
873
+ auto loadBaseAddr =
874
+ builder.loadIfRef (op->getLoc (), baseAddr.getVarPtrPtr ());
875
+ mlir::omp::MapInfoOp newBaseAddrMapOp =
876
+ builder.create <mlir::omp::MapInfoOp>(
877
+ op->getLoc (), loadBaseAddr.getType (), loadBaseAddr,
878
+ baseAddr.getVarTypeAttr (), baseAddr.getMapTypeAttr (),
879
+ baseAddr.getMapCaptureTypeAttr (), mlir::Value{}, members,
880
+ membersAttr, baseAddr.getBounds (),
881
+ /* mapperId*/ mlir::FlatSymbolRefAttr (), op.getNameAttr (),
882
+ /* partial_map=*/ builder.getBoolAttr (false ));
883
+ op.replaceAllUsesWith (newBaseAddrMapOp.getResult ());
884
+ op->erase ();
885
+ baseAddr.erase ();
886
+ }
887
+
704
888
// This pass executes on omp::MapInfoOp's containing descriptor based types
705
889
// (allocatables, pointers, assumed shape etc.) and expanding them into
706
890
// multiple omp::MapInfoOp's for each pointer member contained within the
@@ -730,6 +914,7 @@ class MapInfoFinalizationPass
730
914
// clear all local allocations we made for any boxes in any prior
731
915
// iterations from previous function scopes.
732
916
localBoxAllocas.clear ();
917
+ deferrableDesc.clear ();
733
918
734
919
// First, walk `omp.map.info` ops to see if any of them have varPtrs
735
920
// with an underlying type of fir.char<k, ?>, i.e a character
@@ -1010,6 +1195,36 @@ class MapInfoFinalizationPass
1010
1195
}
1011
1196
});
1012
1197
1198
+ // Now that we've expanded all of our boxes into a descriptor and base
1199
+ // address map where necessary, we check if the map owner is an
1200
+ // enter/exit/target data directive, and if they are we drop the initial
1201
+ // descriptor (top-level parent) and replace it with the
1202
+ // base_address/data.
1203
+ //
1204
+ // This circumvents issues with stack allocated descriptors bound to
1205
+ // device colliding which in Flang is rather trivial for a user to do by
1206
+ // accident due to the rather pervasive local intermediate descriptor
1207
+ // generation that occurs whenever you pass boxes around different scopes.
1208
+ // In OpenMP 6+ mapping these would be a user error as the tools required
1209
+ // to circumvent these issues are provided by the spec (ref_ptr/ptee map
1210
+ // types), but in prior specifications these tools are not available and
1211
+ // it becomes an implementation issue for us to solve.
1212
+ //
1213
+ // We do this by dropping the top-level descriptor which will be the stack
1214
+ // descriptor when we perform enter/exit maps, as we don't want these to
1215
+ // be bound until necessary which is when we utilise the descriptor type
1216
+ // within a target region. At which point we map the relevant descriptor
1217
+ // data and the runtime should correctly associate the data with the
1218
+ // descriptor and bind together and allow clean mapping and execution.
1219
+ for (auto *op : deferrableDesc) {
1220
+ auto mapOp = llvm::dyn_cast<mlir::omp::MapInfoOp>(op);
1221
+ mlir::Operation *targetUser = getFirstTargetUser (mapOp);
1222
+ assert (targetUser && " expected user of map operation was not found" );
1223
+ builder.setInsertionPoint (mapOp);
1224
+ removeTopLevelDescriptor (mapOp, builder, targetUser);
1225
+ addImplicitDescriptorMapToTargetDataOp (mapOp, builder, *targetUser);
1226
+ }
1227
+
1013
1228
// Wait until after we have generated all of our maps to add them onto
1014
1229
// the target's block arguments, simplifying the process as there would be
1015
1230
// no need to avoid accidental duplicate additions.
0 commit comments