Skip to content

Commit c6440f4

Browse files
committed
[MLIR][OpenMP] Add OpenMPToLLVMIRTranslation support for is_device_ptr
1 parent 52a71ea commit c6440f4

File tree

4 files changed

+100
-17
lines changed

4 files changed

+100
-17
lines changed

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ subroutine mapType_array
3333
!$omp end target
3434
end subroutine mapType_array
3535

36+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
37+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 288]
38+
subroutine mapType_is_device_ptr
39+
use iso_c_binding, only : c_ptr
40+
type(c_ptr) :: p
41+
!$omp target is_device_ptr(p)
42+
!$omp end target
43+
end subroutine mapType_is_device_ptr
44+
3645
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 24, i64 8, i64 0]
3746
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976711169, i64 281474976711171, i64 281474976711187]
3847
subroutine mapType_ptr

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
332332
op.getInReductionSyms())
333333
result = todo("in_reduction");
334334
};
335-
auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
336-
if (!op.getIsDevicePtrVars().empty())
337-
result = todo("is_device_ptr");
338-
};
335+
auto checkIsDevicePtr = [](auto, LogicalResult &) {};
339336
auto checkLinear = [&todo](auto op, LogicalResult &result) {
340337
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
341338
result = todo("linear");
@@ -3996,6 +3993,9 @@ static void collectMapDataFromMapOperands(
39963993
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
39973994
auto mapType = convertClauseMapFlags(mapOp.getMapType());
39983995
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3996+
bool isDevicePtr =
3997+
(mapOp.getMapType() & omp::ClauseMapFlags::storage) ==
3998+
omp::ClauseMapFlags::storage;
39993999

40004000
mapData.OriginalValue.push_back(origValue);
40014001
mapData.BasePointers.push_back(origValue);
@@ -4006,7 +4006,12 @@ static void collectMapDataFromMapOperands(
40064006
mapData.Sizes.push_back(
40074007
builder.getInt64(dl.getTypeSize(mapOp.getVarType())));
40084008
mapData.MapClause.push_back(mapOp.getOperation());
4009-
if (llvm::to_underlying(mapType & mapTypeAlways)) {
4009+
if (isDevicePtr) {
4010+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4011+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4012+
mapData.Types.push_back(mapType);
4013+
mapData.Mappers.push_back(nullptr);
4014+
} else if (llvm::to_underlying(mapType & mapTypeAlways)) {
40104015
// Descriptors are mapped with the ALWAYS flag, since they can get
40114016
// rematerialized, so the address of the decriptor for a given object
40124017
// may change from one place to another.
@@ -4029,7 +4034,8 @@ static void collectMapDataFromMapOperands(
40294034
mapData.Names.push_back(LLVM::createMappingInformation(
40304035
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
40314036
mapData.DevicePointers.push_back(
4032-
llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4037+
isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4038+
: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
40334039
mapData.IsAMapping.push_back(false);
40344040
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
40354041
}

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -238,17 +238,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
238238

239239
// -----
240240

241-
llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
242-
// expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
243-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
244-
omp.target is_device_ptr(%x : !llvm.ptr) {
245-
omp.terminator
246-
}
247-
llvm.return
248-
}
249-
250-
// -----
251-
252241
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
253242
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
254243
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
! Validate that a device pointer allocated via OpenMP runtime APIs can be
2+
! consumed by a TARGET region using the is_device_ptr clause.
3+
! REQUIRES: flang, amdgcn-amd-amdhsa
4+
! UNSUPPORTED: nvptx64-nvidia-cuda
5+
! UNSUPPORTED: nvptx64-nvidia-cuda-LTO
6+
! UNSUPPORTED: aarch64-unknown-linux-gnu
7+
! UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
8+
! UNSUPPORTED: x86_64-unknown-linux-gnu
9+
! UNSUPPORTED: x86_64-unknown-linux-gnu-LTO
10+
11+
! RUN: %libomptarget-compile-fortran-run-and-check-generic
12+
13+
program is_device_ptr_target
14+
use omp_lib
15+
use iso_c_binding
16+
implicit none
17+
18+
integer, parameter :: n = 4
19+
integer, target :: host(n)
20+
type(c_ptr) :: device_ptr
21+
integer(c_int) :: rc
22+
integer :: i
23+
24+
do i = 1, n
25+
host(i) = i
26+
end do
27+
28+
device_ptr = omp_target_alloc(int(n, c_size_t) * int(c_sizeof(host(1)), c_size_t), &
29+
omp_get_default_device())
30+
if (.not. c_associated(device_ptr)) then
31+
print *, "device alloc failed"
32+
stop 1
33+
end if
34+
35+
rc = omp_target_memcpy(device_ptr, c_loc(host), &
36+
int(n, c_size_t) * int(c_sizeof(host(1)), c_size_t), &
37+
0_c_size_t, 0_c_size_t, &
38+
omp_get_default_device(), omp_get_initial_device())
39+
if (rc .ne. 0) then
40+
print *, "host->device memcpy failed"
41+
call omp_target_free(device_ptr, omp_get_default_device())
42+
stop 1
43+
end if
44+
45+
call fill_on_device(device_ptr)
46+
47+
rc = omp_target_memcpy(c_loc(host), device_ptr, &
48+
int(n, c_size_t) * int(c_sizeof(host(1)), c_size_t), &
49+
0_c_size_t, 0_c_size_t, &
50+
omp_get_initial_device(), omp_get_default_device())
51+
call omp_target_free(device_ptr, omp_get_default_device())
52+
53+
if (rc .ne. 0) then
54+
print *, "device->host memcpy failed"
55+
stop 1
56+
end if
57+
58+
if (all(host == [2, 4, 6, 8])) then
59+
print *, "PASS"
60+
else
61+
print *, "FAIL", host
62+
end if
63+
64+
contains
65+
subroutine fill_on_device(ptr)
66+
type(c_ptr) :: ptr
67+
integer, pointer :: p(:)
68+
call c_f_pointer(ptr, p, [n])
69+
70+
!$omp target is_device_ptr(ptr)
71+
p(1) = 2
72+
p(2) = 4
73+
p(3) = 6
74+
p(4) = 8
75+
!$omp end target
76+
end subroutine fill_on_device
77+
end program is_device_ptr_target
78+
79+
!CHECK: PASS

0 commit comments

Comments
 (0)