Skip to content

Commit d556ab5

Browse files
committed
[Flang][OpenMP] Fix for regression of smoke test flang-325095
When an assumed-size/assumed-shape array, an allocatable, or anything else tied to a descriptor is used in a use_device_addr/use_device_ptr clause of a target data directive, we need to emit an implicit map of the descriptor for that directive, so that the descriptor data is also on device for the subsequent target data region. We are effectively generating accesses to the device descriptor for the duration of the target region, not just to reach the data but also to compute accesses to the data via the descriptor's bounds information. If the descriptor is not present, we don't perform the correct data accesses, which can result in writing to only the first index of the data or even not writing the data to any index. So we need to be sure we have the appropriate descriptor information on device. This PR does that in two parts. First, it adds an implicit descriptor map in MapInfoFinalization for target data when a descriptor type is found in a use_device_addr/use_device_ptr clause. Second, it modifies the later lowering in OpenMPToLLVMIRTranslation to make the duplicate map check a little more restrictive, to prevent it from miscategorizing the regular descriptor map as a duplicate of the use_device_addr/use_device_ptr map. They aren't the same: one is just the descriptor, the other is the full descriptor/data member map done for use_device_ptr/use_device_addr (so not mapping the data). It might be possible to do it differently by "merging" the closely aligned entries, making sure we don't lose the member maps and adjusting the map types to reflect a combination of the solo descriptor map and the use_device_addr/use_device_ptr map, but this might be a bit more complicated and error prone than tightening up the restriction on what's considered a duplicate entry.
1 parent 73aa7a0 commit d556ab5

File tree

3 files changed

+115
-30
lines changed

3 files changed

+115
-30
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,28 @@ class MapInfoFinalizationPass
313313
return false;
314314
}
315315

316+
/// Returns true when \p mapOp is the defining operation of one of the
/// use_device_addr operands of \p userOp. Any \p userOp that is not an
/// omp.target_data operation yields false.
bool isUseDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
  assert(userOp && "Expecting non-null argument");

  auto dataDirective = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp);
  if (!dataDirective)
    return false;

  // Look for an operand of the clause that is produced by this map op.
  for (mlir::Value deviceAddrVar : dataDirective.getUseDeviceAddrVars())
    if (deviceAddrVar.getDefiningOp() == mapOp)
      return true;

  return false;
}
326+
327+
/// Returns true when \p mapOp is the defining operation of one of the
/// use_device_ptr operands of \p userOp. Any \p userOp that is not an
/// omp.target_data operation yields false.
bool isUseDevicePtr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
  assert(userOp && "Expecting non-null argument");

  auto dataDirective = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp);
  if (!dataDirective)
    return false;

  // Look for an operand of the clause that is produced by this map op.
  for (mlir::Value devicePtrVar : dataDirective.getUseDevicePtrVars())
    if (devicePtrVar.getDefiningOp() == mapOp)
      return true;

  return false;
}
337+
316338
mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
317339
fir::FirOpBuilder &builder,
318340
mlir::Operation *target) {
@@ -544,6 +566,66 @@ class MapInfoFinalizationPass
544566
return nullptr;
545567
}
546568

569+
/// Appends an implicit, descriptor-only map to an omp.target_data operation
/// when \p op (a descriptor map that has members) is consumed by that
/// operation's use_device_addr/use_device_ptr clause and the operation does
/// not already carry an explicit map of the same descriptor.
///
/// The created map reuses \p op's location, result type, varPtr, var type,
/// capture type and name; it has the TO | ALWAYS | DESCRIPTOR map flags and
/// no varPtrPtr, members, bounds or mapper.
///
/// \param op      the map operation to inspect.
/// \param builder builder used to create the new map operation.
/// \param target  the operation using \p op; must be non-null. Anything that
///                is not an omp.target_data operation is left untouched.
void addImplictDescriptorMapToTargetDataOp(mlir::omp::MapInfoOp op,
                                           fir::FirOpBuilder &builder,
                                           mlir::Operation *target) {
  assert(target && "Expecting non-null argument");

  // Checks if the descriptor is already explicitly mapped on the target data
  // directive (and not just present on a use_device_addr/ptr clause). If it
  // is, no implicit map of the descriptor is needed.
  auto explicitMappingPresent = [](mlir::omp::MapInfoOp descMap,
                                   mlir::omp::TargetDataOp tarData) {
    // Verify a top-level mapping exists with the same varPtr and varPtrPtr.
    // The map type should always be To for a descriptor, which is all we
    // really care about for this mapping, as we aim to make sure the
    // descriptor is always present on device if we're expecting to access the
    // underlying data.
    for (mlir::Value mapVar : tarData.getMapVars()) {
      // Skip operands that are not produced by a map info operation instead
      // of dereferencing a null op.
      auto existingMap = mapVar.getDefiningOp<mlir::omp::MapInfoOp>();
      if (!existingMap)
        continue;

      if (existingMap.getVarPtr() == descMap.getVarPtr() &&
          existingMap.getVarPtrPtr() == descMap.getVarPtrPtr())
        return true;
    }

    return false;
  };

  // If we're not a top-level descriptor with members (e.g. we are a member of
  // a derived type), or the user is not a target data directive, we do not
  // want to perform this step.
  auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target);
  if (!targetDataOp || op.getMembers().empty())
    return;

  if (!isUseDeviceAddr(op, target) && !isUseDevicePtr(op, target))
    return;

  if (explicitMappingPresent(op, targetDataOp))
    return;

  mlir::omp::MapInfoOp newDescParentMapOp =
      builder.create<mlir::omp::MapInfoOp>(
          op->getLoc(), op.getResult().getType(), op.getVarPtr(),
          op.getVarTypeAttr(),
          builder.getIntegerAttr(
              builder.getIntegerType(64, false),
              llvm::to_underlying(
                  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
                  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS |
                  llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DESCRIPTOR)),
          op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
          mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
          /*bounds=*/mlir::SmallVector<mlir::Value>{},
          /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
          /*partial_map=*/builder.getBoolAttr(false));

  targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
}
628+
547629
void removeTopLevelDescriptor(mlir::omp::MapInfoOp op,
548630
fir::FirOpBuilder &builder,
549631
mlir::Operation *target) {
@@ -880,6 +962,7 @@ class MapInfoFinalizationPass
880962
assert(targetUser && "expected user of map operation was not found");
881963
builder.setInsertionPoint(op);
882964
removeTopLevelDescriptor(op, builder, targetUser);
965+
addImplictDescriptorMapToTargetDataOp(op, builder, targetUser);
883966
}
884967
});
885968

flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
!CHECK: func.func @{{.*}}only_use_device_ptr()
99

10-
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
10+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
1111
subroutine only_use_device_ptr
1212
use iso_c_binding
1313
integer, pointer, dimension(:) :: array
@@ -19,7 +19,7 @@ subroutine only_use_device_ptr
1919
end subroutine
2020

2121
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
22-
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
22+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
2323
subroutine mix_use_device_ptr_and_addr
2424
use iso_c_binding
2525
integer, pointer, dimension(:) :: array
@@ -30,38 +30,38 @@ subroutine mix_use_device_ptr_and_addr
3030
!$omp end target data
3131
end subroutine
3232

33-
!CHECK: func.func @{{.*}}only_use_device_addr()
34-
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
35-
subroutine only_use_device_addr
36-
use iso_c_binding
37-
integer, pointer, dimension(:) :: array
38-
real, pointer :: pa(:)
39-
type(c_ptr) :: cptr
33+
!CHECK: func.func @{{.*}}only_use_device_addr()
34+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
35+
subroutine only_use_device_addr
36+
use iso_c_binding
37+
integer, pointer, dimension(:) :: array
38+
real, pointer :: pa(:)
39+
type(c_ptr) :: cptr
4040

4141
!$omp target data use_device_addr(pa, cptr, array)
4242
!$omp end target data
4343
end subroutine
4444

45-
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
46-
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
47-
subroutine mix_use_device_ptr_and_addr_and_map
48-
use iso_c_binding
49-
integer :: i, j
50-
integer, pointer, dimension(:) :: array
51-
real, pointer :: pa(:)
52-
type(c_ptr) :: cptr
45+
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
46+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
47+
subroutine mix_use_device_ptr_and_addr_and_map
48+
use iso_c_binding
49+
integer :: i, j
50+
integer, pointer, dimension(:) :: array
51+
real, pointer :: pa(:)
52+
type(c_ptr) :: cptr
5353

5454
!$omp target data use_device_ptr(pa, cptr) use_device_addr(array) map(tofrom: i, j)
5555
!$omp end target data
5656
end subroutine
5757

58-
!CHECK: func.func @{{.*}}only_use_map()
59-
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
60-
subroutine only_use_map
61-
use iso_c_binding
62-
integer, pointer, dimension(:) :: array
63-
real, pointer :: pa(:)
64-
type(c_ptr) :: cptr
58+
!CHECK: func.func @{{.*}}only_use_map()
59+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
60+
subroutine only_use_map
61+
use iso_c_binding
62+
integer, pointer, dimension(:) :: array
63+
real, pointer :: pa(:)
64+
type(c_ptr) :: cptr
6565

6666
!$omp target data map(pa, cptr, array)
6767
!$omp end target data

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3723,17 +3723,19 @@ static void collectMapDataFromMapOperands(
37233723
}
37243724

37253725
auto findMapInfo = [&mapData](llvm::Value *val,
3726-
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3727-
unsigned index = 0;
3726+
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
3727+
size_t memberCount) {
37283728
bool found = false;
3729-
for (llvm::Value *basePtr : mapData.OriginalValue) {
3730-
if (basePtr == val && mapData.IsAMapping[index]) {
3729+
for (size_t index = 0; index < mapData.OriginalValue.size(); ++index) {
3730+
auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[index]);
3731+
if (mapData.IsAMapping[index] && mapData.Pointers[index] == val &&
3732+
mapData.BasePointers[index] == val &&
3733+
memberCount == mapInfoOp.getMembers().size()) {
37313734
found = true;
37323735
mapData.Types[index] |=
37333736
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
37343737
mapData.DevicePointers[index] = devInfoTy;
37353738
}
3736-
index++;
37373739
}
37383740
return found;
37393741
};
@@ -3748,7 +3750,7 @@ static void collectMapDataFromMapOperands(
37483750
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
37493751

37503752
// Check if map info is already present for this entry.
3751-
if (!findMapInfo(origValue, devInfoTy)) {
3753+
if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
37523754
mapData.OriginalValue.push_back(origValue);
37533755
mapData.Pointers.push_back(mapData.OriginalValue.back());
37543756
mapData.IsDeclareTarget.push_back(false);

0 commit comments

Comments
 (0)