Skip to content

Commit e4f9b18

Browse files
authored
[Flang][OpenMP] Fix for regression of smoke test flang-325095 (llvm#2279)
2 parents 6731a07 + d556ab5 commit e4f9b18

File tree

3 files changed

+115
-30
lines changed

3 files changed

+115
-30
lines changed

flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,28 @@ class MapInfoFinalizationPass
313313
return false;
314314
}
315315

316+
bool isUseDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
317+
assert(userOp && "Expecting non-null argument");
318+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
319+
for (mlir::Value uda : targetDataOp.getUseDeviceAddrVars()) {
320+
if (uda.getDefiningOp() == mapOp)
321+
return true;
322+
}
323+
}
324+
return false;
325+
}
326+
327+
bool isUseDevicePtr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
328+
assert(userOp && "Expecting non-null argument");
329+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(userOp)) {
330+
for (mlir::Value udp : targetDataOp.getUseDevicePtrVars()) {
331+
if (udp.getDefiningOp() == mapOp)
332+
return true;
333+
}
334+
}
335+
return false;
336+
}
337+
316338
mlir::omp::MapInfoOp genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
317339
fir::FirOpBuilder &builder,
318340
mlir::Operation *target) {
@@ -544,6 +566,66 @@ class MapInfoFinalizationPass
544566
return nullptr;
545567
}
546568

569+
void addImplictDescriptorMapToTargetDataOp(mlir::omp::MapInfoOp op,
570+
fir::FirOpBuilder &builder,
571+
mlir::Operation *target) {
572+
// Checks if the map is present as an explicit map already on the target
573+
// data directive, and not just present on a use_device_addr/ptr, as if
574+
// that's the case, we should not need to add an implicit map for the
575+
// descriptor.
576+
auto explicitMappingPresent = [](mlir::omp::MapInfoOp op,
577+
mlir::omp::TargetDataOp tarData) {
578+
// Verify top-level descriptor mapping is at least equal with same
579+
// varPtr, the map type should always be To for a descriptor, which is
580+
// all we really care about for this mapping as we aim to make sure the
581+
// descriptor is always present on device if we're expecting to access
582+
// the underlying data.
583+
if (tarData.getMapVars().empty())
584+
return false;
585+
586+
for (mlir::Value mapVar : tarData.getMapVars()) {
587+
auto mapOp =
588+
llvm::dyn_cast<mlir::omp::MapInfoOp>(mapVar.getDefiningOp());
589+
if (mapOp.getVarPtr() == op.getVarPtr() &&
590+
mapOp.getVarPtrPtr() == mapOp.getVarPtrPtr()) {
591+
return true;
592+
}
593+
}
594+
595+
return false;
596+
};
597+
598+
// if we're not a top level descriptor with members (e.g. member of a
599+
// derived type), we do not want to perform this step.
600+
if (!llvm::isa<mlir::omp::TargetDataOp>(target) || op.getMembers().empty())
601+
return;
602+
603+
if (!isUseDeviceAddr(op, target) && !isUseDevicePtr(op, target))
604+
return;
605+
606+
auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target);
607+
if (explicitMappingPresent(op, targetDataOp))
608+
return;
609+
610+
mlir::omp::MapInfoOp newDescParentMapOp =
611+
builder.create<mlir::omp::MapInfoOp>(
612+
op->getLoc(), op.getResult().getType(), op.getVarPtr(),
613+
op.getVarTypeAttr(),
614+
builder.getIntegerAttr(
615+
builder.getIntegerType(64, false),
616+
llvm::to_underlying(
617+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
618+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS |
619+
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DESCRIPTOR)),
620+
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{},
621+
mlir::SmallVector<mlir::Value>{}, mlir::ArrayAttr{},
622+
/*bounds=*/mlir::SmallVector<mlir::Value>{},
623+
/*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
624+
/*partial_map=*/builder.getBoolAttr(false));
625+
626+
targetDataOp.getMapVarsMutable().append({newDescParentMapOp});
627+
}
628+
547629
void removeTopLevelDescriptor(mlir::omp::MapInfoOp op,
548630
fir::FirOpBuilder &builder,
549631
mlir::Operation *target) {
@@ -880,6 +962,7 @@ class MapInfoFinalizationPass
880962
assert(targetUser && "expected user of map operation was not found");
881963
builder.setInsertionPoint(op);
882964
removeTopLevelDescriptor(op, builder, targetUser);
965+
addImplictDescriptorMapToTargetDataOp(op, builder, targetUser);
883966
}
884967
});
885968

flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
!CHECK: func.func @{{.*}}only_use_device_ptr()
99

10-
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
10+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
1111
subroutine only_use_device_ptr
1212
use iso_c_binding
1313
integer, pointer, dimension(:) :: array
@@ -19,7 +19,7 @@ subroutine only_use_device_ptr
1919
end subroutine
2020

2121
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
22-
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
22+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
2323
subroutine mix_use_device_ptr_and_addr
2424
use iso_c_binding
2525
integer, pointer, dimension(:) :: array
@@ -30,38 +30,38 @@ subroutine mix_use_device_ptr_and_addr
3030
!$omp end target data
3131
end subroutine
3232

33-
!CHECK: func.func @{{.*}}only_use_device_addr()
34-
!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
35-
subroutine only_use_device_addr
36-
use iso_c_binding
37-
integer, pointer, dimension(:) :: array
38-
real, pointer :: pa(:)
39-
type(c_ptr) :: cptr
33+
!CHECK: func.func @{{.*}}only_use_device_addr()
34+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
35+
subroutine only_use_device_addr
36+
use iso_c_binding
37+
integer, pointer, dimension(:) :: array
38+
real, pointer :: pa(:)
39+
type(c_ptr) :: cptr
4040

4141
!$omp target data use_device_addr(pa, cptr, array)
4242
!$omp end target data
4343
end subroutine
4444

45-
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
46-
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
47-
subroutine mix_use_device_ptr_and_addr_and_map
48-
use iso_c_binding
49-
integer :: i, j
50-
integer, pointer, dimension(:) :: array
51-
real, pointer :: pa(:)
52-
type(c_ptr) :: cptr
45+
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
46+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
47+
subroutine mix_use_device_ptr_and_addr_and_map
48+
use iso_c_binding
49+
integer :: i, j
50+
integer, pointer, dimension(:) :: array
51+
real, pointer :: pa(:)
52+
type(c_ptr) :: cptr
5353

5454
!$omp target data use_device_ptr(pa, cptr) use_device_addr(array) map(tofrom: i, j)
5555
!$omp end target data
5656
end subroutine
5757

58-
!CHECK: func.func @{{.*}}only_use_map()
59-
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
60-
subroutine only_use_map
61-
use iso_c_binding
62-
integer, pointer, dimension(:) :: array
63-
real, pointer :: pa(:)
64-
type(c_ptr) :: cptr
58+
!CHECK: func.func @{{.*}}only_use_map()
59+
!CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
60+
subroutine only_use_map
61+
use iso_c_binding
62+
integer, pointer, dimension(:) :: array
63+
real, pointer :: pa(:)
64+
type(c_ptr) :: cptr
6565

6666
!$omp target data map(pa, cptr, array)
6767
!$omp end target data

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3543,17 +3543,19 @@ static void collectMapDataFromMapOperands(
35433543
}
35443544

35453545
auto findMapInfo = [&mapData](llvm::Value *val,
3546-
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3547-
unsigned index = 0;
3546+
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy,
3547+
size_t memberCount) {
35483548
bool found = false;
3549-
for (llvm::Value *basePtr : mapData.OriginalValue) {
3550-
if (basePtr == val && mapData.IsAMapping[index]) {
3549+
for (size_t index = 0; index < mapData.OriginalValue.size(); ++index) {
3550+
auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[index]);
3551+
if (mapData.IsAMapping[index] && mapData.Pointers[index] == val &&
3552+
mapData.BasePointers[index] == val &&
3553+
memberCount == mapInfoOp.getMembers().size()) {
35513554
found = true;
35523555
mapData.Types[index] |=
35533556
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
35543557
mapData.DevicePointers[index] = devInfoTy;
35553558
}
3556-
index++;
35573559
}
35583560
return found;
35593561
};
@@ -3568,7 +3570,7 @@ static void collectMapDataFromMapOperands(
35683570
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
35693571

35703572
// Check if map info is already present for this entry.
3571-
if (!findMapInfo(origValue, devInfoTy)) {
3573+
if (!findMapInfo(origValue, devInfoTy, mapOp.getMembers().size())) {
35723574
mapData.OriginalValue.push_back(origValue);
35733575
mapData.Pointers.push_back(mapData.OriginalValue.back());
35743576
mapData.IsDeclareTarget.push_back(false);

0 commit comments

Comments
 (0)