Skip to content

Commit 2479918

Browse files
committed
[MLIR][OpenMP] Add OpenMPToLLVMIRTranslation support for is_device_ptr
This PR adds support for the OpenMP is_device_ptr clause in the MLIR to LLVM IR translation for target regions. The is_device_ptr clause allows device pointers (allocated via OpenMP runtime APIs) to be used directly in target regions without implicit mapping.
1 parent f5d82d2 commit 2479918

File tree

5 files changed

+82
-17
lines changed

5 files changed

+82
-17
lines changed

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ subroutine mapType_array
3333
!$omp end target
3434
end subroutine mapType_array
3535

36+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
37+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 288]
38+
subroutine mapType_is_device_ptr
39+
use iso_c_binding, only : c_ptr
40+
type(c_ptr) :: p
41+
!$omp target is_device_ptr(p)
42+
!$omp end target
43+
end subroutine mapType_is_device_ptr
44+
3645
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 24, i64 8, i64 0]
3746
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976711169, i64 281474976711171, i64 281474976711187]
3847
subroutine mapType_ptr

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
332332
op.getInReductionSyms())
333333
result = todo("in_reduction");
334334
};
335-
auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
336-
if (!op.getIsDevicePtrVars().empty())
337-
result = todo("is_device_ptr");
338-
};
339335
auto checkLinear = [&todo](auto op, LogicalResult &result) {
340336
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
341337
result = todo("linear");
@@ -444,7 +440,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
444440
checkBare(op, result);
445441
checkDevice(op, result);
446442
checkInReduction(op, result);
447-
checkIsDevicePtr(op, result);
448443
})
449444
.Default([](Operation &) {
450445
// Assume all clauses for an operation can be translated unless they are
@@ -3875,6 +3870,11 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
38753870
if (mapTypeToBool(omp::ClauseMapFlags::attach))
38763871
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
38773872

3873+
if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
3874+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
3875+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
3876+
}
3877+
38783878
return mapType;
38793879
}
38803880

@@ -3996,6 +3996,9 @@ static void collectMapDataFromMapOperands(
39963996
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
39973997
auto mapType = convertClauseMapFlags(mapOp.getMapType());
39983998
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3999+
bool isDevicePtr =
4000+
(mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4001+
omp::ClauseMapFlags::none;
39994002

40004003
mapData.OriginalValue.push_back(origValue);
40014004
mapData.BasePointers.push_back(origValue);
@@ -4029,7 +4032,8 @@ static void collectMapDataFromMapOperands(
40294032
mapData.Names.push_back(LLVM::createMappingInformation(
40304033
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
40314034
mapData.DevicePointers.push_back(
4032-
llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4035+
isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4036+
: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
40334037
mapData.IsAMapping.push_back(false);
40344038
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
40354039
}

mlir/test/Target/LLVMIR/omptarget-llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
622622
// CHECK: br label %[[VAL_40]]
623623
// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
624624
// CHECK: ret void
625+
626+
// -----
627+
628+
module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
629+
llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) {
630+
%map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)
631+
map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""}
632+
omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) {
633+
omp.terminator
634+
}
635+
llvm.return
636+
}
637+
}
638+
639+
// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8]
640+
// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288]
641+
// CHECK-LABEL: define void @_QPomp_target_is_device_ptr

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -238,17 +238,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
238238

239239
// -----
240240

241-
llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
242-
// expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
243-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
244-
omp.target is_device_ptr(%x : !llvm.ptr) {
245-
omp.terminator
246-
}
247-
llvm.return
248-
}
249-
250-
// -----
251-
252241
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
253242
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
254243
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
! Validate that a device pointer obtained via omp_get_mapped_ptr can be used
2+
! inside a TARGET region with the is_device_ptr clause.
3+
! REQUIRES: flang, amdgcn-amd-amdhsa
4+
5+
! RUN: %libomptarget-compile-fortran-run-and-check-generic
6+
7+
program is_device_ptr_target
8+
use iso_c_binding, only : c_ptr, c_loc
9+
implicit none
10+
11+
interface
12+
function omp_get_mapped_ptr(host_ptr, device_num) &
13+
bind(C, name="omp_get_mapped_ptr")
14+
use iso_c_binding, only : c_ptr, c_int
15+
type(c_ptr) :: omp_get_mapped_ptr
16+
type(c_ptr), value :: host_ptr
17+
integer(c_int), value :: device_num
18+
end function omp_get_mapped_ptr
19+
end interface
20+
21+
integer, parameter :: n = 4
22+
integer, parameter :: dev = 0
23+
integer, target :: a(n)
24+
type(c_ptr) :: dptr
25+
integer :: flag
26+
27+
a = [2, 4, 6, 8]
28+
flag = 0
29+
30+
!$omp target data map(tofrom: a, flag)
31+
dptr = omp_get_mapped_ptr(c_loc(a), dev)
32+
33+
!$omp target is_device_ptr(dptr) map(tofrom: flag)
34+
flag = flag + 1
35+
!$omp end target
36+
!$omp end target data
37+
38+
if (flag .eq. 1 .and. all(a == [2, 4, 6, 8])) then
39+
print *, "PASS"
40+
else
41+
print *, "FAIL", a
42+
end if
43+
44+
end program is_device_ptr_target
45+
46+
!CHECK: PASS

0 commit comments

Comments
 (0)