Skip to content

Commit 3987e87

Browse files
agozillon and svkeerthy
authored and committed
[Flang][OpenMP] Defer descriptor mapping for assumed dummy argument types (#154349)
This PR adds deferral of descriptor maps until they are necessary for assumed dummy argument types. The intent is to avoid a problem where a user can inadvertently map a temporary local descriptor to device without their knowledge and proceed to never unmap it. This temporary local descriptor remains lodged in OpenMP device memory and the next time another variable or descriptor residing in the same stack address is mapped we incur a runtime OpenMP map error as we try to remap the same address. This fix was discussed with the OpenMP committee and applies to OpenMP 5.2 and below, future versions of OpenMP can avoid this issue via the attach semantics added to the specification.
1 parent 364fbbd commit 3987e87

File tree

4 files changed

+420
-8
lines changed

4 files changed

+420
-8
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 220 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ class MapInfoFinalizationPass
7777
/// | |
7878
std::map<mlir::Operation *, mlir::Value> localBoxAllocas;
7979

80+
// List of deferrable descriptors to process at the end of
81+
// the pass.
82+
llvm::SmallVector<mlir::Operation *> deferrableDesc;
83+
8084
/// Return true if the given path exists in a list of paths.
8185
static bool
8286
containsPath(const llvm::SmallVectorImpl<llvm::SmallVector<int64_t>> &paths,
@@ -183,6 +187,40 @@ class MapInfoFinalizationPass
183187
newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
184188
}
185189

190+
// Check if the declaration operation we have refers to a dummy
191+
// function argument.
192+
bool isDummyArgument(mlir::Value mappedValue) {
193+
if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
194+
mappedValue.getDefiningOp()))
195+
if (auto dummyScope = declareOp.getDummyScope())
196+
return true;
197+
return false;
198+
}
199+
200+
// Relevant for OpenMP < 5.2, where attach semantics and rules don't exist.
201+
// As descriptors were an unspoken implementation detail in these versions
202+
// there's certain cases where the user (and the compiler implementation)
203+
// can create data mapping errors by having temporary descriptors stuck
204+
// in memory. The main example is calling a 'target enter data map'
205+
// without a corresponding exit on an assumed shape or size dummy
206+
// argument, a local stack descriptor is generated, gets mapped and
207+
// is then left on device. A user doesn't realize what they've done as
208+
// the OpenMP specification isn't explicit on descriptor handling in
209+
// earlier versions and as far as Fortran is concerned this is something
210+
// hidden from a user. To avoid this we can defer the descriptor mapping
211+
// in these cases until target or target data regions, when we can be
212+
// sure they have a clear limited scope on device.
213+
bool canDeferDescriptorMapping(mlir::Value descriptor) {
214+
if (fir::isAllocatableType(descriptor.getType()) ||
215+
fir::isPointerType(descriptor.getType()))
216+
return false;
217+
if (isDummyArgument(descriptor) &&
218+
(fir::isAssumedType(descriptor.getType()) ||
219+
fir::isAssumedShape(descriptor.getType())))
220+
return true;
221+
return false;
222+
}
223+
186224
/// getMemberUserList gathers all users of a particular MapInfoOp that are
187225
/// other MapInfoOp's and places them into the mapMemberUsers list, which
188226
/// records the map that the current argument MapInfoOp "op" is part of
@@ -234,13 +272,16 @@ class MapInfoFinalizationPass
234272
/// fir::BoxOffsetOp we utilise to access the descriptor data's
235273
/// base address can be utilised.
236274
mlir::Value getDescriptorFromBoxMap(mlir::omp::MapInfoOp boxMap,
237-
fir::FirOpBuilder &builder) {
275+
fir::FirOpBuilder &builder,
276+
bool &canDescBeDeferred) {
238277
mlir::Value descriptor = boxMap.getVarPtr();
239278
if (!fir::isTypeWithDescriptor(boxMap.getVarType()))
240279
if (auto addrOp = mlir::dyn_cast_if_present<fir::BoxAddrOp>(
241280
boxMap.getVarPtr().getDefiningOp()))
242281
descriptor = addrOp.getVal();
243282

283+
canDescBeDeferred = canDeferDescriptorMapping(descriptor);
284+
244285
if (!mlir::isa<fir::BaseBoxType>(descriptor.getType()) &&
245286
!fir::factory::isOptionalArgument(descriptor.getDefiningOp()))
246287
return descriptor;
@@ -391,8 +432,7 @@ class MapInfoFinalizationPass
391432

392433
/// Check if the mapOp is present in the HasDeviceAddr clause on
393434
/// the userOp. Only applies to TargetOp.
394-
bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
395-
assert(userOp && "Expecting non-null argument");
435+
bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
396436
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp)) {
397437
for (mlir::Value hda : targetOp.getHasDeviceAddrVars()) {
398438
if (hda.getDefiningOp() == mapOp)
@@ -402,6 +442,26 @@ class MapInfoFinalizationPass
402442
return false;
403443
}
404444

445+
bool isUseDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
446+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
447+
for (mlir::Value uda : targetDataOp.getUseDeviceAddrVars()) {
448+
if (uda.getDefiningOp() == mapOp)
449+
return true;
450+
}
451+
}
452+
return false;
453+
}
454+
455+
bool isUseDevicePtr(mlir::omp::MapInfoOp mapOp, mlir::Operation &userOp) {
456+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
457+
for (mlir::Value udp : targetDataOp.getUseDevicePtrVars()) {
458+
if (udp.getDefiningOp() == mapOp)
459+
return true;
460+
}
461+
}
462+
return false;
463+
}
464+
405465
mlir::omp::MapInfoOp genBoxcharMemberMap(mlir::omp::MapInfoOp op,
406466
fir::FirOpBuilder &builder) {
407467
if (!op.getMembers().empty())
@@ -466,12 +526,14 @@ class MapInfoFinalizationPass
466526

467527
// TODO: map the addendum segment of the descriptor, similarly to the
468528
// base address/data pointer member.
469-
mlir::Value descriptor = getDescriptorFromBoxMap(op, builder);
529+
bool descCanBeDeferred = false;
530+
mlir::Value descriptor =
531+
getDescriptorFromBoxMap(op, builder, descCanBeDeferred);
470532

471533
mlir::ArrayAttr newMembersAttr;
472534
mlir::SmallVector<mlir::Value> newMembers;
473535
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
474-
bool isHasDeviceAddrFlag = isHasDeviceAddr(op, target);
536+
bool isHasDeviceAddrFlag = isHasDeviceAddr(op, *target);
475537

476538
if (!mapMemberUsers.empty() || !op.getMembers().empty())
477539
getMemberIndicesAsVectors(
@@ -553,6 +615,10 @@ class MapInfoFinalizationPass
553615
/*partial_map=*/builder.getBoolAttr(false));
554616
op.replaceAllUsesWith(newDescParentMapOp.getResult());
555617
op->erase();
618+
619+
if (descCanBeDeferred)
620+
deferrableDesc.push_back(newDescParentMapOp);
621+
556622
return newDescParentMapOp;
557623
}
558624

@@ -701,6 +767,124 @@ class MapInfoFinalizationPass
701767
return nullptr;
702768
}
703769

770+
void addImplicitDescriptorMapToTargetDataOp(mlir::omp::MapInfoOp op,
771+
fir::FirOpBuilder &builder,
772+
mlir::Operation &target) {
773+
// Checks if the map is present as an explicit map already on the target
774+
// data directive, and not just present on a use_device_addr/ptr, as if
775+
// that's the case, we should not need to add an implicit map for the
776+
// descriptor.
777+
auto explicitMappingPresent = [](mlir::omp::MapInfoOp op,
778+
mlir::omp::TargetDataOp tarData) {
779+
// Verify top-level descriptor mapping is at least equal with same
780+
// varPtr, the map type should always be To for a descriptor, which is
781+
// all we really care about for this mapping as we aim to make sure the
782+
// descriptor is always present on device if we're expecting to access
783+
// the underlying data.
784+
if (tarData.getMapVars().empty())
785+
return false;
786+
787+
for (mlir::Value mapVar : tarData.getMapVars()) {
788+
auto mapOp = llvm::cast<mlir::omp::MapInfoOp>(mapVar.getDefiningOp());
789+
if (mapOp.getVarPtr() == op.getVarPtr() &&
790+
mapOp.getVarPtrPtr() == op.getVarPtrPtr()) {
791+
return true;
792+
}
793+
}
794+
795+
return false;
796+
};
797+
798+
// if we're not a top level descriptor with members (e.g. member of a
799+
// derived type), we do not want to perform this step.
800+
if (!llvm::isa<mlir::omp::TargetDataOp>(target) || op.getMembers().empty())
801+
return;
802+
803+
if (!isUseDeviceAddr(op, target) && !isUseDevicePtr(op, target))
804+
return;
805+
806+
auto targetDataOp = llvm::cast<mlir::omp::TargetDataOp>(target);
807+
if (explicitMappingPresent(op, targetDataOp))
808+
return;
809+
810+
mlir::omp::MapInfoOp newDescParentMapOp =
811+
builder.create<mlir::omp::MapInfoOp>(
812+
op->getLoc(), op.getResult().getType(), op.getVarPtr(),
813+
op.getVarTypeAttr(),
814+
builder.getIntegerAttr(
815+
builder.getIntegerType(64, false),
816+
llvm::to_underlying(
817+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
818+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)),
819+
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
820+
mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
821+
/*bounds=*/mlir::SmallVector<mlir::Value>{},
822+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
823+
/*partial_map=*/builder.getBoolAttr(false));
824+
825+
targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
826+
}
827+
828+
void removeTopLevelDescriptor(mlir::omp::MapInfoOp op,
829+
fir::FirOpBuilder &builder,
830+
mlir::Operation *target) {
831+
if (llvm::isa<mlir::omp::TargetOp, mlir::omp::TargetDataOp,
832+
mlir::omp::DeclareMapperInfoOp>(target))
833+
return;
834+
835+
// if we're not a top level descriptor with members (e.g. member of a
836+
// derived type), we do not want to perform this step.
837+
if (op.getMembers().empty())
838+
return;
839+
840+
mlir::SmallVector<mlir::Value> members = op.getMembers();
841+
mlir::omp::MapInfoOp baseAddr =
842+
mlir::dyn_cast_or_null<mlir::omp::MapInfoOp>(
843+
members.front().getDefiningOp());
844+
assert(baseAddr && "Expected member to be MapInfoOp");
845+
members.erase(members.begin());
846+
847+
llvm::SmallVector<llvm::SmallVector<int64_t>> memberIndices;
848+
getMemberIndicesAsVectors(op, memberIndices);
849+
850+
// Can skip the extra processing if there's only 1 member as it'd
851+
// be the base addresses, which we're promoting to the parent.
852+
mlir::ArrayAttr membersAttr;
853+
if (memberIndices.size() > 1) {
854+
memberIndices.erase(memberIndices.begin());
855+
membersAttr = builder.create2DI64ArrayAttr(memberIndices);
856+
}
857+
858+
// VarPtrPtr is tied to detecting if something is a pointer in the later
859+
// lowering currently, this at the moment comes tied with
860+
// OMP_MAP_PTR_AND_OBJ being applied which breaks the problem this tries to
861+
// solve by emitting a 8-byte mapping tied to the descriptor address (even
862+
// if we only emit a single map). So we circumvent this by removing the
863+
// varPtrPtr mapping, however, a side effect of this is we lose the
864+
// additional load from the backend tied to this which is required for
865+
// correctness and getting the correct address of the data to perform our
866+
// mapping. So we do our load at this stage.
867+
// TODO/FIXME: Tidy up the OMP_MAP_PTR_AND_OBJ and varPtrPtr being tied to
868+
// if something is a pointer to try and tidy up the implementation a bit.
869+
// This is an unfortunate complexity from push-back from upstream. We
870+
// could also emit a load at this level for all base addresses as well,
871+
// which in turn will simplify the later lowering a bit as well. But first
872+
// need to see how well this alteration works.
873+
auto loadBaseAddr =
874+
builder.loadIfRef(op->getLoc(), baseAddr.getVarPtrPtr());
875+
mlir::omp::MapInfoOp newBaseAddrMapOp =
876+
builder.create<mlir::omp::MapInfoOp>(
877+
op->getLoc(), loadBaseAddr.getType(), loadBaseAddr,
878+
baseAddr.getVarTypeAttr(), baseAddr.getMapTypeAttr(),
879+
baseAddr.getMapCaptureTypeAttr(), mlir::Value{}, members,
880+
membersAttr, baseAddr.getBounds(),
881+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
882+
/*partial_map=*/builder.getBoolAttr(false));
883+
op.replaceAllUsesWith(newBaseAddrMapOp.getResult());
884+
op->erase();
885+
baseAddr.erase();
886+
}
887+
704888
// This pass executes on omp::MapInfoOp's containing descriptor based types
705889
// (allocatables, pointers, assumed shape etc.) and expanding them into
706890
// multiple omp::MapInfoOp's for each pointer member contained within the
@@ -730,6 +914,7 @@ class MapInfoFinalizationPass
730914
// clear all local allocations we made for any boxes in any prior
731915
// iterations from previous function scopes.
732916
localBoxAllocas.clear();
917+
deferrableDesc.clear();
733918

734919
// First, walk `omp.map.info` ops to see if any of them have varPtrs
735920
// with an underlying type of fir.char<k, ?>, i.e a character
@@ -1010,6 +1195,36 @@ class MapInfoFinalizationPass
10101195
}
10111196
});
10121197

1198+
// Now that we've expanded all of our boxes into a descriptor and base
1199+
// address map where necessary, we check if the map owner is an
1200+
// enter/exit/target data directive, and if they are we drop the initial
1201+
// descriptor (top-level parent) and replace it with the
1202+
// base_address/data.
1203+
//
1204+
// This circumvents issues with stack allocated descriptors bound to
1205+
// device colliding which in Flang is rather trivial for a user to do by
1206+
// accident due to the rather pervasive local intermediate descriptor
1207+
// generation that occurs whenever you pass boxes around different scopes.
1208+
// In OpenMP 6+ mapping these would be a user error as the tools required
1209+
// to circumvent these issues are provided by the spec (ref_ptr/ptee map
1210+
// types), but in prior specifications these tools are not available and
1211+
// it becomes an implementation issue for us to solve.
1212+
//
1213+
// We do this by dropping the top-level descriptor which will be the stack
1214+
// descriptor when we perform enter/exit maps, as we don't want these to
1215+
// be bound until necessary which is when we utilise the descriptor type
1216+
// within a target region. At which point we map the relevant descriptor
1217+
// data and the runtime should correctly associate the data with the
1218+
// descriptor and bind together and allow clean mapping and execution.
1219+
for (auto *op : deferrableDesc) {
1220+
auto mapOp = llvm::dyn_cast<mlir::omp::MapInfoOp>(op);
1221+
mlir::Operation *targetUser = getFirstTargetUser(mapOp);
1222+
assert(targetUser && "expected user of map operation was not found");
1223+
builder.setInsertionPoint(mapOp);
1224+
removeTopLevelDescriptor(mapOp, builder, targetUser);
1225+
addImplicitDescriptorMapToTargetDataOp(mapOp, builder, *targetUser);
1226+
}
1227+
10131228
// Wait until after we have generated all of our maps to add them onto
10141229
// the target's block arguments, simplifying the process as there would be
10151230
// no need to avoid accidental duplicate additions.

0 commit comments

Comments
 (0)