Skip to content

Commit d3d6cb1

Browse files
committed
[Flang][OpenMP] Defer descriptor mapping for assumed dummy argument types
This PR adds deferral of descriptor maps until they are neccessary for assumed dummy argument types. The intent is to avoid a problem where a user can inadvertently map a temporary local descriptor to device without their knowledge and proceed to never unmap it. This temporary local descriptor remains lodged in OpenMP device memory and the next time another variable or descriptor residing in the same stack address is mapped we incur a runtime OpenMP map error as we try to remap the same address. This fix was discussed with the OpenMP committee and applies to OpenMP 5.2 and below, future versions of OpenMP can avoid this issue via the attatch semantics added to the specification.
1 parent a22e966 commit d3d6cb1

File tree

4 files changed

+420
-8
lines changed

4 files changed

+420
-8
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 220 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,44 @@ class MapInfoFinalizationPass
7575
/// | |
7676
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
7777

78+
// List of deferrable descriptors to process at the end of
79+
// the pass.
80+
llvm::SmallVector<mlir::Operation *> deferrableDesc;
81+
82+
// Check if the declaration operation we have refers to a dummy
83+
// function argument.
84+
bool isDummyArgument(mlir::Value mappedValue) {
85+
if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
86+
mappedValue.getDefiningOp()))
87+
if (auto dummyScope = declareOp.getDummyScope())
88+
return true;
89+
return false;
90+
}
91+
92+
// Relevant for OpenMP < 5.2, where attach semantics and rules don't exist.
93+
// As descriptors were an unspoken implementation detail in these versions
94+
// there's certain cases where the user (and the compiler implementation)
95+
// can create data mapping errors by having temporary descriptors stuck
96+
// in memory. The main example is calling an 'target enter data map'
97+
// without a corresponding exit on an assumed shape or size dummy
98+
// argument, a local stack descriptor is generated, gets mapped and
99+
// is then left on device. A user doesn't realize what they've done as
100+
// the OpenMP specification isn't explicit on descriptor handling in
101+
// earlier versions and as far as Fortran is concerned this si something
102+
// hidden from a user. To avoid this we can defer the descriptor mapping
103+
// in these cases until target or target data regions, when we can be
104+
// sure they have a clear limited scope on device.
105+
bool canDeferDescriptorMapping(mlir::Value descriptor) {
106+
if (fir::isAllocatableType(descriptor.getType()) ||
107+
fir::isPointerType(descriptor.getType()))
108+
return false;
109+
if (isDummyArgument(descriptor) &&
110+
(fir::isAssumedType(descriptor.getType()) ||
111+
fir::isAssumedShape(descriptor.getType())))
112+
return true;
113+
return false;
114+
}
115+
78116
/// getMemberUserList gathers all users of a particular MapInfoOp that are
79117
/// other MapInfoOp's and places them into the mapMemberUsers list, which
80118
/// records the map that the current argument MapInfoOp "op" is part of
@@ -126,13 +164,16 @@ class MapInfoFinalizationPass
126164
/// fir::BoxOffsetOp we utilise to access the descriptor datas
127165
/// base address can be utilised.
128166
mlir::Value getDescriptorFromBoxMap(mlir::omp::MapInfoOp boxMap,
129-
fir::FirOpBuilder &builder) {
167+
fir::FirOpBuilder &builder,
168+
bool &canDescBeDeferred) {
130169
mlir::Value descriptor = boxMap.getVarPtr();
131170
if (!fir::isTypeWithDescriptor(boxMap.getVarType()))
132171
if (auto addrOp = mlir::dyn_cast_if_present<fir::BoxAddrOp>(
133172
boxMap.getVarPtr().getDefiningOp()))
134173
descriptor = addrOp.getVal();
135174

175+
canDescBeDeferred = canDeferDescriptorMapping(descriptor);
176+
136177
if (!mlir::isa<fir::BaseBoxType>(descriptor.getType()) &&
137178
!fir::factory::isOptionalArgument(descriptor.getDefiningOp()))
138179
return descriptor;
@@ -283,8 +324,7 @@ class MapInfoFinalizationPass
283324

284325
/// Check if the mapOp is present in the HasDeviceAddr clause on
285326
/// the userOp. Only applies to TargetOp.
286-
bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
287-
assert(userOp && "Expecting non-null argument");
327+
bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
288328
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp)) {
289329
for (mlir::Value hda : targetOp.getHasDeviceAddrVars()) {
290330
if (hda.getDefiningOp() == mapOp)
@@ -294,6 +334,26 @@ class MapInfoFinalizationPass
294334
return false;
295335
}
296336

337+
bool isUseDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
338+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
339+
for (mlir::Value uda : targetDataOp.getUseDeviceAddrVars()) {
340+
if (uda.getDefiningOp() == mapOp)
341+
return true;
342+
}
343+
}
344+
return false;
345+
}
346+
347+
bool isUseDevicePtr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
348+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
349+
for (mlir::Value udp : targetDataOp.getUseDevicePtrVars()) {
350+
if (udp.getDefiningOp() == mapOp)
351+
return true;
352+
}
353+
}
354+
return false;
355+
}
356+
297357
mlir::omp::MapInfoOp genBoxcharMemberMap(mlir::omp::MapInfoOp op,
298358
fir::FirOpBuilder &builder) {
299359
if (!op.getMembers().empty())
@@ -358,12 +418,14 @@ class MapInfoFinalizationPass
358418

359419
// TODO: map the addendum segment of the descriptor, similarly to the
360420
// base address/data pointer member.
361-
mlir::Value descriptor = getDescriptorFromBoxMap(op, builder);
421+
bool descCanBeDeferred = false;
422+
mlir::Value descriptor =
423+
getDescriptorFromBoxMap(op, builder, descCanBeDeferred);
362424

363425
mlir::ArrayAttr newMembersAttr;
364426
mlir::SmallVector<mlir::Value> newMembers;
365427
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
366-
bool IsHasDeviceAddr = isHasDeviceAddr(op, target);
428+
bool IsHasDeviceAddr = isHasDeviceAddr(op, *target);
367429

368430
if (!mapMemberUsers.empty() || !op.getMembers().empty())
369431
getMemberIndicesAsVectors(
@@ -445,6 +507,10 @@ class MapInfoFinalizationPass
445507
/*partial_map=*/builder.getBoolAttr(false));
446508
op.replaceAllUsesWith(newDescParentMapOp.getResult());
447509
op->erase();
510+
511+
if (descCanBeDeferred)
512+
deferrableDesc.push_back(newDescParentMapOp);
513+
448514
return newDescParentMapOp;
449515
}
450516

@@ -593,6 +659,124 @@ class MapInfoFinalizationPass
593659
return nullptr;
594660
}
595661

662+
void addImplicitDescriptorMapToTargetDataOp(mlir::omp::MapInfoOp op,
663+
fir::FirOpBuilder &builder,
664+
mlir::Operation &target) {
665+
// Checks if the map is present as an explicit map already on the target
666+
// data directive, and not just present on a use_device_addr/ptr, as if
667+
// that's the case, we should not need to add an implicit map for the
668+
// descriptor.
669+
auto explicitMappingPresent = [](mlir::omp::MapInfoOp op,
670+
mlir::omp::TargetDataOp tarData) {
671+
// Verify top-level descriptor mapping is at least equal with same
672+
// varPtr, the map type should always be To for a descriptor, which is
673+
// all we really care about for this mapping as we aim to make sure the
674+
// descriptor is always present on device if we're expecting to access
675+
// the underlying data.
676+
if (tarData.getMapVars().empty())
677+
return false;
678+
679+
for (mlir::Value mapVar : tarData.getMapVars()) {
680+
auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(mapVar.getDefiningOp());
681+
if (mapOp.getVarPtr() == op.getVarPtr() &&
682+
mapOp.getVarPtrPtr() == op.getVarPtrPtr()) {
683+
return true;
684+
}
685+
}
686+
687+
return false;
688+
};
689+
690+
// if we're not a top level descriptor with members (e.g. member of a
691+
// derived type), we do not want to perform this step.
692+
if (!llvm::isa<mlir::omp::TargetDataOp>(target) || op.getMembers().empty())
693+
return;
694+
695+
if (!isUseDeviceAddr(op, target) && !isUseDevicePtr(op, target))
696+
return;
697+
698+
auto targetDataOp = llvm::cast<mlir::omp::TargetDataOp>(target);
699+
if (explicitMappingPresent(op, targetDataOp))
700+
return;
701+
702+
mlir::omp::MapInfoOp newDescParentMapOp =
703+
builder.create<mlir::omp::MapInfoOp>(
704+
op->getLoc(), op.getResult().getType(), op.getVarPtr(),
705+
op.getVarTypeAttr(),
706+
builder.getIntegerAttr(
707+
builder.getIntegerType(64, false),
708+
llvm::to_underlying(
709+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
710+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
711+
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
712+
mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
713+
/*bounds=*/mlir::SmallVector<mlir::Value>{},
714+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
715+
/*partial_map=*/builder.getBoolAttr(false));
716+
717+
targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
718+
}
719+
720+
void removeTopLevelDescriptor(mlir::omp::MapInfoOp op,
721+
fir::FirOpBuilder &builder,
722+
mlir::Operation *target) {
723+
if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
724+
mlir::omp::DeclareMapperInfoOp>(target))
725+
return;
726+
727+
// if we're not a top level descriptor with members (e.g. member of a
728+
// derived type), we do not want to perform this step.
729+
if (op.getMembers().empty())
730+
return;
731+
732+
mlir::SmallVector<mlir::Value> members = op.getMembers();
733+
mlir::omp::MapInfoOp baseAddr =
734+
mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
735+
members.front().getDefiningOp());
736+
assert(baseAddr && "Expected member to be MapInfoOp");
737+
members.erase(members.begin());
738+
739+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
740+
getMemberIndicesAsVectors(op, memberIndices);
741+
742+
// Can skip the extra processing if there's only 1 member as it'd
743+
// be the base addresses, which we're promoting to the parent.
744+
mlir::ArrayAttr membersAttr;
745+
if (memberIndices.size() > 1) {
746+
memberIndices.erase(memberIndices.begin());
747+
membersAttr = builder.create2DI64ArrayAttr(memberIndices);
748+
}
749+
750+
// VarPtrPtr is tied to detecting if something is a pointer in the later
751+
// lowering currently, this at the moment comes tied with
752+
// OMP_MAP_PTR_AND_OBJ being applied which breaks the problem this tries to
753+
// solve by emitting a 8-byte mapping tied to the descriptor address (even
754+
// if we only emit a single map). So we circumvent this by removing the
755+
// varPtrPtr mapping, however, a side affect of this is we lose the
756+
// additional load from the backend tied to this which is required for
757+
// correctness and getting the correct address of the data to perform our
758+
// mapping. So we do our load at this stage.
759+
// TODO/FIXME: Tidy up the OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
760+
// if something is a pointer to try and tidy up the implementation a bit.
761+
// This is an unfortunate complexity from push-back from upstream. We
762+
// could also emit a load at this level for all base addresses as well,
763+
// which in turn will simplify the later lowering a bit as well. But first
764+
// need to see how well this alteration works.
765+
auto loadBaseAddr =
766+
builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr());
767+
mlir::omp::MapInfoOp newBaseAddrMapOp =
768+
builder.create<mlir::omp::MapInfoOp>(
769+
op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
770+
baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
771+
baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members,
772+
membersAttr, baseAddr.getBounds(),
773+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
774+
/*partial_map=*/builder.getBoolAttr(false));
775+
op.replaceAllUsesWith(newBaseAddrMapOp.getResult());
776+
op->erase();
777+
baseAddr.erase();
778+
}
779+
596780
// This pass executes on omp::MapInfoOp's containing descriptor based types
597781
// (allocatables, pointers, assumed shape etc.) and expanding them into
598782
// multiple omp::MapInfoOp's for each pointer member contained within the
@@ -622,6 +806,7 @@ class MapInfoFinalizationPass
622806
// clear all local allocations we made for any boxes in any prior
623807
// iterations from previous function scopes.
624808
localBoxAllocas.clear();
809+
deferrableDesc.clear();
625810

626811
// First, walk `omp.map.info` ops to see if any of them have varPtrs
627812
// with an underlying type of fir.char<k, ?>, i.e a character
@@ -864,6 +1049,36 @@ class MapInfoFinalizationPass
8641049
}
8651050
});
8661051

1052+
// Now that we've expanded all of our boxes into a descriptor and base
1053+
// address map where necessary, we check if the map owner is an
1054+
// enter/exit/target data directive, and if they are we drop the initial
1055+
// descriptor (top-level parent) and replace it with the
1056+
// base_address/data.
1057+
//
1058+
// This circumvents issues with stack allocated descriptors bound to
1059+
// device colliding which in Flang is rather trivial for a user to do by
1060+
// accident due to the rather pervasive local intermediate descriptor
1061+
// generation that occurs whenever you pass boxes around different scopes.
1062+
// In OpenMP 6+ mapping these would be a user error as the tools required
1063+
// to circumvent these issues are provided by the spec (ref_ptr/ptee map
1064+
// types), but in prior specifications these tools are not available and
1065+
// it becomes an implementation issue for us to solve.
1066+
//
1067+
// We do this by dropping the top-level descriptor which will be the stack
1068+
// descriptor when we perform enter/exit maps, as we don't want these to
1069+
// be bound until necessary which is when we utilise the descriptor type
1070+
// within a target region. At which point we map the relevant descriptor
1071+
// data and the runtime should correctly associate the data with the
1072+
// descriptor and bind together and allow clean mapping and execution.
1073+
for (auto *op : deferrableDesc) {
1074+
auto mapOp = llvm::dyn_cast<mlir::omp::MapInfoOp>(op);
1075+
mlir::Operation *targetUser = getFirstTargetUser(mapOp);
1076+
assert(targetUser && "expected user of map operation was not found");
1077+
builder.setInsertionPoint(mapOp);
1078+
removeTopLevelDescriptor(mapOp, builder, targetUser);
1079+
addImplicitDescriptorMapToTargetDataOp(mapOp, builder, *targetUser);
1080+
}
1081+
8671082
// Wait until after we have generated all of our maps to add them onto
8681083
// the target's block arguments, simplifying the process as there would be
8691084
// no need to avoid accidental duplicate additions.

0 commit comments

Comments
 (0)