-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[Flang][OpenMP] Update MapInfoFinalization to use BlockArgs Interface and modify use_device_ptr/addr to be order independent #113919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp Author: None (agozillon) ChangesThis patch primarily updates the MapInfoFinalization pass to utilise the BlockArgument interface. It also shuffles newly added arguments the MapInfoFinalization passes to the end of the BlockArg/Relevant MapInfo lists, instead of one prior to the owning descriptor type. During this it was noted that the use_device_ptr/addr handling of target data was a little bit too order dependent so I've attempted to make it less so, as we cannot depend on argument ordering to be the same as Fortran for any future frontends. Patch is 35.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113919.diff 10 Files Affected:
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 7ebeb51cf3dec7..6eb65b25594295 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -125,61 +125,82 @@ class MapInfoFinalizationPass
// TODO: map the addendum segment of the descriptor, similarly to the
// above base address/data pointer member.
- auto addOperands = [&](mlir::OperandRange &operandsArr,
- mlir::MutableOperandRange &mutableOpRange,
- auto directiveOp) {
- llvm::SmallVector<mlir::Value> newMapOps;
- for (size_t i = 0; i < operandsArr.size(); ++i) {
- if (operandsArr[i] == op) {
- // Push new implicit maps generated for the descriptor.
- newMapOps.push_back(baseAddr);
+ mlir::omp::MapInfoOp newDescParentMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op->getLoc(), op.getResult().getType(), descriptor,
+ mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
+ /*members_index=*/
+ mlir::DenseIntElementsAttr::get(
+ mlir::VectorType::get(
+ llvm::ArrayRef<int64_t>({1, 1}),
+ mlir::IntegerType::get(builder.getContext(), 32)),
+ llvm::ArrayRef<int32_t>({0})),
+ /*bounds=*/mlir::SmallVector<mlir::Value>{},
+ builder.getIntegerAttr(builder.getIntegerType(64, false),
+ op.getMapType().value()),
+ op.getMapCaptureTypeAttr(), op.getNameAttr(),
+ op.getPartialMapAttr());
+ op.replaceAllUsesWith(newDescParentMapOp.getResult());
+ op->erase();
- // for TargetOp's which have IsolatedFromAbove we must align the
- // new additional map operand with an appropriate BlockArgument,
- // as the printing and later processing currently requires a 1:1
- // mapping of BlockArgs to MapInfoOp's at the same placement in
- // each array (BlockArgs and MapOperands).
- if (directiveOp) {
- directiveOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
+ auto addOperands = [&](mlir::OperandRange &mapVarsArr,
+ mlir::MutableOperandRange &mutableOpRange,
+ mlir::Operation *directiveOp,
+ mlir::omp::MapInfoOp newDesc,
+ unsigned blockArgInsertIndex = 0,
+ bool insertBlockArgs = true) {
+ if (llvm::is_contained(mapVarsArr, newDesc.getResult())) {
+ llvm::SmallVector<mlir::Value> newMapOps{mapVarsArr};
+ for (auto mapMember : newDesc.getMembers()) {
+ if (!llvm::is_contained(mapVarsArr, mapMember)) {
+ newMapOps.push_back(mapMember);
+ if (directiveOp && insertBlockArgs) {
+ directiveOp->getRegion(0).insertArgument(
+ blockArgInsertIndex, mapMember.getType(), mapMember.getLoc());
+ }
+ blockArgInsertIndex++;
}
}
- newMapOps.push_back(operandsArr[i]);
+ mutableOpRange.assign(newMapOps);
}
- mutableOpRange.assign(newMapOps);
};
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
+
if (auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
- mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapVars();
+ mlir::OperandRange mapVarsArr = mapClauseOwner.getMapVars();
mlir::MutableOperandRange mapMutableOpRange =
mapClauseOwner.getMapVarsMutable();
- mlir::omp::TargetOp targetOp =
- llvm::dyn_cast<mlir::omp::TargetOp>(target);
- addOperands(mapOperandsArr, mapMutableOpRange, targetOp);
+ unsigned blockArgInsertIndex =
+ argIface
+ ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
+ : 0;
+ addOperands(mapVarsArr, mapMutableOpRange, argIface.getOperation(),
+ newDescParentMapOp, blockArgInsertIndex,
+ !llvm::isa<mlir::omp::TargetDataOp>(target));
}
+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddrVars();
mlir::MutableOperandRange useDevAddrMutableOpRange =
targetDataOp.getUseDeviceAddrVarsMutable();
- addOperands(useDevAddrArr, useDevAddrMutableOpRange, targetDataOp);
- }
+ addOperands(useDevAddrArr, useDevAddrMutableOpRange, target,
+ newDescParentMapOp,
+ argIface.getUseDeviceAddrBlockArgsStart() +
+ argIface.numUseDeviceAddrBlockArgs());
- mlir::Value newDescParentMapOp = builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), op.getResult().getType(), descriptor,
- mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
- /*varPtrPtr=*/mlir::Value{},
- /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
- /*members_index=*/
- mlir::DenseIntElementsAttr::get(
- mlir::VectorType::get(
- llvm::ArrayRef<int64_t>({1, 1}),
- mlir::IntegerType::get(builder.getContext(), 32)),
- llvm::ArrayRef<int32_t>({0})),
- /*bounds=*/mlir::SmallVector<mlir::Value>{},
- builder.getIntegerAttr(builder.getIntegerType(64, false),
- op.getMapType().value()),
- op.getMapCaptureTypeAttr(), op.getNameAttr(), op.getPartialMapAttr());
- op.replaceAllUsesWith(newDescParentMapOp);
- op->erase();
+ mlir::OperandRange useDevPtrArr = targetDataOp.getUseDevicePtrVars();
+ mlir::MutableOperandRange useDevPtrMutableOpRange =
+ targetDataOp.getUseDevicePtrVarsMutable();
+ addOperands(useDevPtrArr, useDevPtrMutableOpRange, target,
+ newDescParentMapOp,
+ argIface.getUseDevicePtrBlockArgsStart() +
+ argIface.numUseDevicePtrBlockArgs());
+ }
}
// We add all mapped record members not directly used in the target region
diff --git a/flang/test/Lower/OpenMP/allocatable-map.f90 b/flang/test/Lower/OpenMP/allocatable-map.f90
index a9f576a6f09992..c1f94f41901489 100644
--- a/flang/test/Lower/OpenMP/allocatable-map.f90
+++ b/flang/test/Lower/OpenMP/allocatable-map.f90
@@ -4,7 +4,7 @@
!HLFIRDIALECT: %[[BOX_OFF:.*]] = fir.box_offset %[[POINTER]]#1 base_addr : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.llvm_ptr<!fir.ref<i32>>
!HLFIRDIALECT: %[[POINTER_MAP_MEMBER:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr(%[[BOX_OFF]] : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!HLFIRDIALECT: %[[POINTER_MAP:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[POINTER_MAP_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "point"}
-!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP_MEMBER]] -> {{.*}}, %[[POINTER_MAP]] -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP]] -> {{.*}}, %[[POINTER_MAP_MEMBER]] -> {{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
subroutine pointer_routine()
integer, pointer :: point
!$omp target map(tofrom:point)
diff --git a/flang/test/Lower/OpenMP/array-bounds.f90 b/flang/test/Lower/OpenMP/array-bounds.f90
index 09498ca6cdde99..40fd276f10462b 100644
--- a/flang/test/Lower/OpenMP/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/array-bounds.f90
@@ -53,7 +53,7 @@ module assumed_array_routines
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %0 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
!HOST: %[[MAP:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP_INFO_MEMBER]] -> %{{.*}}, %[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
+!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, %[[MAP_INFO_MEMBER]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine assumed_shape_array(arr_read_write)
integer, intent(inout) :: arr_read_write(:)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 63a43e750979d5..f8bbde93072e91 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -528,9 +528,9 @@ subroutine omp_target_device_addr
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
- !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+ !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
- !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
+ !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
!CHECK: %[[C10:.*]] = arith.constant 10 : i32
!CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[A_ADDR:.*]] = fir.box_addr %[[A_BOX]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index cb26246a6e80f0..8c1abad8eaa8d5 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -6,7 +6,8 @@
! use_device_ptr to use_device_addr works, without breaking any functionality.
!CHECK: func.func @{{.*}}only_use_device_ptr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine only_use_device_ptr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -18,7 +19,7 @@ subroutine only_use_device_ptr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -30,7 +31,7 @@ subroutine mix_use_device_ptr_and_addr
end subroutine
!CHECK: func.func @{{.*}}only_use_device_addr()
- !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+ !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine only_use_device_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -42,7 +43,7 @@ subroutine only_use_device_addr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr_and_map
use iso_c_binding
integer :: i, j
@@ -55,7 +56,7 @@ subroutine mix_use_device_ptr_and_addr_and_map
end subroutine
!CHECK: func.func @{{.*}}only_use_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine only_use_map
use iso_c_binding
integer, pointer, dimension(:) :: array
diff --git a/flang/test/Transforms/omp-map-info-finalization.fir b/flang/test/Transforms/omp-map-info-finalization.fir
index fa7b65d41929b7..de0ad2143fc853 100644
--- a/flang/test/Transforms/omp-map-info-finalization.fir
+++ b/flang/test/Transforms/omp-map-info-finalization.fir
@@ -39,7 +39,7 @@ module attributes {omp.is_target_device = false} {
// CHECK: %[[BASE_ADDR_OFF_2:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
// CHECK: %[[DESC_MEMBER_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[BASE_ADDR_OFF_2]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
// CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(from) capture(ByRef) members(%[[DESC_MEMBER_MAP_2]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
-// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG3:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
+// CHECK: omp.target map_entries(%[[DESC_PARENT_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP]] -> %[[ARG3:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
// -----
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 27cd38dc3c62d9..fbbaa909fd3dd8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2327,21 +2327,20 @@ static void collectMapDataFromMapOperands(
mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
}
- auto findMapInfo = [&mapData](llvm::Value *val,
- llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
- unsigned index = 0;
- bool found = false;
- for (llvm::Value *basePtr : mapData.OriginalValue) {
- if (basePtr == val && mapData.IsAMapping[index]) {
- found = true;
- mapData.Types[index] |=
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
- mapData.DevicePointers[index] = devInfoTy;
- }
- index++;
- }
- return found;
- };
+ // This function alters the original mapped pointers type if it was present in
+ // a map clause as well as being present in a useDevAddr/Ptr clause.
+ auto alterAndCreateUseDevMapType =
+ [&mapData](llvm::Value *val,
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+ for (auto [i, origVal] : llvm::enumerate(mapData.OriginalValue)) {
+ if (origVal == val && mapData.IsAMapping[i]) {
+ mapData.Types[i] |=
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ mapData.Devi...
[truncated]
|
|
@llvm/pr-subscribers-mlir-llvm Author: None (agozillon) ChangesThis patch primarily updates the MapInfoFinalization pass to utilise the BlockArgument interface. It also shuffles newly added arguments the MapInfoFinalization passes to the end of the BlockArg/Relevant MapInfo lists, instead of one prior to the owning descriptor type. During this it was noted that the use_device_ptr/addr handling of target data was a little bit too order dependent so I've attempted to make it less so, as we cannot depend on argument ordering to be the same as Fortran for any future frontends. Patch is 35.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113919.diff 10 Files Affected:
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 7ebeb51cf3dec7..6eb65b25594295 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -125,61 +125,82 @@ class MapInfoFinalizationPass
// TODO: map the addendum segment of the descriptor, similarly to the
// above base address/data pointer member.
- auto addOperands = [&](mlir::OperandRange &operandsArr,
- mlir::MutableOperandRange &mutableOpRange,
- auto directiveOp) {
- llvm::SmallVector<mlir::Value> newMapOps;
- for (size_t i = 0; i < operandsArr.size(); ++i) {
- if (operandsArr[i] == op) {
- // Push new implicit maps generated for the descriptor.
- newMapOps.push_back(baseAddr);
+ mlir::omp::MapInfoOp newDescParentMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op->getLoc(), op.getResult().getType(), descriptor,
+ mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
+ /*members_index=*/
+ mlir::DenseIntElementsAttr::get(
+ mlir::VectorType::get(
+ llvm::ArrayRef<int64_t>({1, 1}),
+ mlir::IntegerType::get(builder.getContext(), 32)),
+ llvm::ArrayRef<int32_t>({0})),
+ /*bounds=*/mlir::SmallVector<mlir::Value>{},
+ builder.getIntegerAttr(builder.getIntegerType(64, false),
+ op.getMapType().value()),
+ op.getMapCaptureTypeAttr(), op.getNameAttr(),
+ op.getPartialMapAttr());
+ op.replaceAllUsesWith(newDescParentMapOp.getResult());
+ op->erase();
- // for TargetOp's which have IsolatedFromAbove we must align the
- // new additional map operand with an appropriate BlockArgument,
- // as the printing and later processing currently requires a 1:1
- // mapping of BlockArgs to MapInfoOp's at the same placement in
- // each array (BlockArgs and MapOperands).
- if (directiveOp) {
- directiveOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
+ auto addOperands = [&](mlir::OperandRange &mapVarsArr,
+ mlir::MutableOperandRange &mutableOpRange,
+ mlir::Operation *directiveOp,
+ mlir::omp::MapInfoOp newDesc,
+ unsigned blockArgInsertIndex = 0,
+ bool insertBlockArgs = true) {
+ if (llvm::is_contained(mapVarsArr, newDesc.getResult())) {
+ llvm::SmallVector<mlir::Value> newMapOps{mapVarsArr};
+ for (auto mapMember : newDesc.getMembers()) {
+ if (!llvm::is_contained(mapVarsArr, mapMember)) {
+ newMapOps.push_back(mapMember);
+ if (directiveOp && insertBlockArgs) {
+ directiveOp->getRegion(0).insertArgument(
+ blockArgInsertIndex, mapMember.getType(), mapMember.getLoc());
+ }
+ blockArgInsertIndex++;
}
}
- newMapOps.push_back(operandsArr[i]);
+ mutableOpRange.assign(newMapOps);
}
- mutableOpRange.assign(newMapOps);
};
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
+
if (auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
- mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapVars();
+ mlir::OperandRange mapVarsArr = mapClauseOwner.getMapVars();
mlir::MutableOperandRange mapMutableOpRange =
mapClauseOwner.getMapVarsMutable();
- mlir::omp::TargetOp targetOp =
- llvm::dyn_cast<mlir::omp::TargetOp>(target);
- addOperands(mapOperandsArr, mapMutableOpRange, targetOp);
+ unsigned blockArgInsertIndex =
+ argIface
+ ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
+ : 0;
+ addOperands(mapVarsArr, mapMutableOpRange, argIface.getOperation(),
+ newDescParentMapOp, blockArgInsertIndex,
+ !llvm::isa<mlir::omp::TargetDataOp>(target));
}
+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddrVars();
mlir::MutableOperandRange useDevAddrMutableOpRange =
targetDataOp.getUseDeviceAddrVarsMutable();
- addOperands(useDevAddrArr, useDevAddrMutableOpRange, targetDataOp);
- }
+ addOperands(useDevAddrArr, useDevAddrMutableOpRange, target,
+ newDescParentMapOp,
+ argIface.getUseDeviceAddrBlockArgsStart() +
+ argIface.numUseDeviceAddrBlockArgs());
- mlir::Value newDescParentMapOp = builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), op.getResult().getType(), descriptor,
- mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
- /*varPtrPtr=*/mlir::Value{},
- /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
- /*members_index=*/
- mlir::DenseIntElementsAttr::get(
- mlir::VectorType::get(
- llvm::ArrayRef<int64_t>({1, 1}),
- mlir::IntegerType::get(builder.getContext(), 32)),
- llvm::ArrayRef<int32_t>({0})),
- /*bounds=*/mlir::SmallVector<mlir::Value>{},
- builder.getIntegerAttr(builder.getIntegerType(64, false),
- op.getMapType().value()),
- op.getMapCaptureTypeAttr(), op.getNameAttr(), op.getPartialMapAttr());
- op.replaceAllUsesWith(newDescParentMapOp);
- op->erase();
+ mlir::OperandRange useDevPtrArr = targetDataOp.getUseDevicePtrVars();
+ mlir::MutableOperandRange useDevPtrMutableOpRange =
+ targetDataOp.getUseDevicePtrVarsMutable();
+ addOperands(useDevPtrArr, useDevPtrMutableOpRange, target,
+ newDescParentMapOp,
+ argIface.getUseDevicePtrBlockArgsStart() +
+ argIface.numUseDevicePtrBlockArgs());
+ }
}
// We add all mapped record members not directly used in the target region
diff --git a/flang/test/Lower/OpenMP/allocatable-map.f90 b/flang/test/Lower/OpenMP/allocatable-map.f90
index a9f576a6f09992..c1f94f41901489 100644
--- a/flang/test/Lower/OpenMP/allocatable-map.f90
+++ b/flang/test/Lower/OpenMP/allocatable-map.f90
@@ -4,7 +4,7 @@
!HLFIRDIALECT: %[[BOX_OFF:.*]] = fir.box_offset %[[POINTER]]#1 base_addr : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.llvm_ptr<!fir.ref<i32>>
!HLFIRDIALECT: %[[POINTER_MAP_MEMBER:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr(%[[BOX_OFF]] : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!HLFIRDIALECT: %[[POINTER_MAP:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[POINTER_MAP_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "point"}
-!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP_MEMBER]] -> {{.*}}, %[[POINTER_MAP]] -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP]] -> {{.*}}, %[[POINTER_MAP_MEMBER]] -> {{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
subroutine pointer_routine()
integer, pointer :: point
!$omp target map(tofrom:point)
diff --git a/flang/test/Lower/OpenMP/array-bounds.f90 b/flang/test/Lower/OpenMP/array-bounds.f90
index 09498ca6cdde99..40fd276f10462b 100644
--- a/flang/test/Lower/OpenMP/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/array-bounds.f90
@@ -53,7 +53,7 @@ module assumed_array_routines
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %0 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
!HOST: %[[MAP:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP_INFO_MEMBER]] -> %{{.*}}, %[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
+!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, %[[MAP_INFO_MEMBER]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine assumed_shape_array(arr_read_write)
integer, intent(inout) :: arr_read_write(:)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 63a43e750979d5..f8bbde93072e91 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -528,9 +528,9 @@ subroutine omp_target_device_addr
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
- !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+ !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
- !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
+ !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
!CHECK: %[[C10:.*]] = arith.constant 10 : i32
!CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[A_ADDR:.*]] = fir.box_addr %[[A_BOX]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index cb26246a6e80f0..8c1abad8eaa8d5 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -6,7 +6,8 @@
! use_device_ptr to use_device_addr works, without breaking any functionality.
!CHECK: func.func @{{.*}}only_use_device_ptr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine only_use_device_ptr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -18,7 +19,7 @@ subroutine only_use_device_ptr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -30,7 +31,7 @@ subroutine mix_use_device_ptr_and_addr
end subroutine
!CHECK: func.func @{{.*}}only_use_device_addr()
- !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+ !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine only_use_device_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -42,7 +43,7 @@ subroutine only_use_device_addr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr_and_map
use iso_c_binding
integer :: i, j
@@ -55,7 +56,7 @@ subroutine mix_use_device_ptr_and_addr_and_map
end subroutine
!CHECK: func.func @{{.*}}only_use_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine only_use_map
use iso_c_binding
integer, pointer, dimension(:) :: array
diff --git a/flang/test/Transforms/omp-map-info-finalization.fir b/flang/test/Transforms/omp-map-info-finalization.fir
index fa7b65d41929b7..de0ad2143fc853 100644
--- a/flang/test/Transforms/omp-map-info-finalization.fir
+++ b/flang/test/Transforms/omp-map-info-finalization.fir
@@ -39,7 +39,7 @@ module attributes {omp.is_target_device = false} {
// CHECK: %[[BASE_ADDR_OFF_2:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
// CHECK: %[[DESC_MEMBER_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[BASE_ADDR_OFF_2]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
// CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(from) capture(ByRef) members(%[[DESC_MEMBER_MAP_2]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
-// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG3:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
+// CHECK: omp.target map_entries(%[[DESC_PARENT_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP]] -> %[[ARG3:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
// -----
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 27cd38dc3c62d9..fbbaa909fd3dd8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2327,21 +2327,20 @@ static void collectMapDataFromMapOperands(
mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
}
- auto findMapInfo = [&mapData](llvm::Value *val,
- llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
- unsigned index = 0;
- bool found = false;
- for (llvm::Value *basePtr : mapData.OriginalValue) {
- if (basePtr == val && mapData.IsAMapping[index]) {
- found = true;
- mapData.Types[index] |=
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
- mapData.DevicePointers[index] = devInfoTy;
- }
- index++;
- }
- return found;
- };
+ // This function alters the original mapped pointers type if it was present in
+ // a map clause as well as being present in a useDevAddr/Ptr clause.
+ auto alterAndCreateUseDevMapType =
+ [&mapData](llvm::Value *val,
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+ for (auto [i, origVal] : llvm::enumerate(mapData.OriginalValue)) {
+ if (origVal == val && mapData.IsAMapping[i]) {
+ mapData.Types[i] |=
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ mapData.Devi...
[truncated]
|
| mapData.IsAMapping.push_back(false); | ||
| mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp)); | ||
| } | ||
| llvm::omp::OpenMPOffloadMappingFlags mapType = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please explain the change removing the if clause?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I currently depend on the fact that every use device ptr/addr regardless of if it has appeared in a map clause or not is inserted into the map data list here: https://github.com/llvm/llvm-project/pull/113919/files#diff-2cbb5651f4570d81d55ac4198deda0f6f7341b2503479752ef2295da3774c586R3014 to break the ordering requirements, perhaps there's a more trivial way to do so though!
I also think every clause that gets lowered from a map info operation should have it's information inserted, incase we depend on it at a later stage or need to check it for other information, but that might just be because it's something I've found useful!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a but vague on how this is necessary for ordering purposes but I can imagine how it might be needed.
Let me share my perspective:
UseDevicePtr/Addr is really just like enabling an extra flag or mapTypeMpdifier for the mapClause. It should really just be part of MapType or the map clause itself but it's not because only a subset of the directives that allow the map clause also allow the useDevPtr/Addr.
The expected/common behavior is that a variable which occurs in a UseDev clause has already been mapped in one of the MapClause. However, this isn't a rule, hence when we encounter a UseDev clause variable with no previous MapInfo for it, we create a "dummy" mapping for it. This is more in the category "no idea what the user wants here, let he/she bother".
As such, generating separate MapInfo for each and every UseDev clause doesn't really make sense. Moreover, it is only recently we switched the UseDev clauses to make use of the MapOps but that was only to leverage the vivid data type support of the mapOp and nothing else.
With this philosophy in mind I don't think we should always be adding MapInfo for each UseDev, as they are a lack of information in themselves and hold no information wortthy of preservation.
skatrak
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Andrew for looking into this! I just have some small nits and suggestions.
Could you add an MLIR to LLVM IR translation test that checks that the order in which a parent map info and a child appear in an e.g. use_device_addr clause result in equivalent IR?
| auto alterAndCreateUseDevMapType = | ||
| [&mapData](llvm::Value *val, | ||
| llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) { | ||
| for (auto [i, origVal] : llvm::enumerate(mapData.OriginalValue)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Consider using zip_equal instead, if possible (not sure if assignments below are possible using this approach):
| for (auto [i, origVal] : llvm::enumerate(mapData.OriginalValue)) { | |
| for (auto [origVal, isMapping, type, devicePointer] : llvm::zip_equal(mapData.OriginalValue, mapData.IsAMapping, mapData.Types, mapData.DevicePointers)) { |
| createAlteredByCaptureMap(mapData, moduleTranslation, builder); | ||
|
|
||
| auto useDevAndMapped = [&mapData](unsigned mapIdx) { | ||
| if (!mapData.IsAMapping[mapIdx]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Negate condition and return early to reduce nesting.
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Outdated
Show resolved
Hide resolved
edeb77b to
9a0797b
Compare
… and modify use_device_ptr/addr to be order independent This patch primarily updates the MapInfoFinalization pass to utilise the BlockArgument interface. It also shuffles newly added arguments the MapInfoFinalization passes to the end of the BlockArg/Relevant MapInfo lists, instead of one prior to the owning descriptor type. During this it was noted that the use_device_ptr/addr handling of target data was a little bit too order dependent so I've attempted to make it less so, as we cannot depend on argument ordering to be the same as Fortran for any future frontends.
9a0797b to
612f05d
Compare
TIFitis
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making the changes. I am happy with the patch 👍🏽
skatrak
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Andrew, I see that a couple of the spots where I had comments disappeared. Since everything still works, I'm guessing these changes weren't needed, so LGTM!
|
Thank you very much @skatrak :-)! Yes, I cut the PR down a bit after looking into some alternatives! I'll look to land this tomorrow if no further comments from anyone else are made in the meantime! |
This patch primarily updates the MapInfoFinalization pass to utilise the BlockArgument interface. It also shuffles newly added arguments the MapInfoFinalization passes to the end of the BlockArg/Relevant MapInfo lists, instead of one prior to the owning descriptor type.
During this it was noted that the use_device_ptr/addr handling of target data was a little bit too order dependent so I've attempted to make it less so, as we cannot depend on argument ordering to be the same as Fortran for any future frontends.