diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index dd0cb3c42ba26..eed965b9aa9f8 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -44,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) { return mlir::omp::ReductionModifier::defaultmod; } -/// Check for unsupported map operand types. -static void checkMapType(mlir::Location location, mlir::Type type) { - if (auto refType = mlir::dyn_cast(type)) - type = refType.getElementType(); - if (auto boxType = mlir::dyn_cast_or_null(type)) - if (!mlir::isa(boxType.getElementType())) - TODO(location, "OMPD_target_data MapOperand BoxType"); -} - static mlir::omp::ScheduleModifier translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) { switch (m) { @@ -211,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter, ifVal); } -static void addUseDeviceClause( - lower::AbstractConverter &converter, const omp::ObjectList &objects, - llvm::SmallVectorImpl &operands, - llvm::SmallVectorImpl &useDeviceSyms) { - genObjectList(objects, converter, operands); - for (mlir::Value &operand : operands) - checkMapType(operand.getLoc(), operand.getType()); - - for (const omp::Object &object : objects) - useDeviceSyms.push_back(object.sym()); -} - //===----------------------------------------------------------------------===// // ClauseProcessor unique clauses //===----------------------------------------------------------------------===// @@ -1225,14 +1204,26 @@ bool ClauseProcessor::processInReduction( } bool ClauseProcessor::processIsDevicePtr( - mlir::omp::IsDevicePtrClauseOps &result, + lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const { - return findRepeatableClause( - [&](const omp::clause::IsDevicePtr &devPtrClause, - const parser::CharBlock &) { - addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars, - isDeviceSyms); + std::map parentMemberIndices; + bool clauseFound = findRepeatableClause( + [&](const omp::clause::IsDevicePtr &clause, + const parser::CharBlock &source) { + mlir::Location location = converter.genLocation(source); + // Force a map so the descriptor is materialized on the device with the + // device address inside. + mlir::omp::ClauseMapFlags mapTypeBits = + mlir::omp::ClauseMapFlags::is_device_ptr | + mlir::omp::ClauseMapFlags::to; + processMapObjects(stmtCtx, location, clause.v, mapTypeBits, + parentMemberIndices, result.isDevicePtrVars, + isDeviceSyms); }); + + insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices, + result.isDevicePtrVars, isDeviceSyms); + return clauseFound; } bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 54ec9c5f0d752..3485a4ed1581f 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -134,7 +134,7 @@ class ClauseProcessor { mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result, llvm::SmallVectorImpl &outReductionSyms) const; bool processIsDevicePtr( - mlir::omp::IsDevicePtrClauseOps &result, + lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const; bool processLinear(mlir::omp::LinearClauseOps &result) const; bool diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 0a200388a36e5..b298fb83f4203 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1672,7 +1672,7 @@ static void genTargetClauses( hostEvalInfo->collectValues(clauseOps.hostEvalVars); } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); - cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); + cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown, &mapSyms); cp.processNowait(clauseOps); @@ -2488,13 +2488,15 @@ static bool isDuplicateMappedSymbol( const semantics::Symbol &sym, const llvm::SetVector &privatizedSyms, const llvm::SmallVectorImpl &hasDevSyms, - const llvm::SmallVectorImpl &mappedSyms) { + const llvm::SmallVectorImpl &mappedSyms, + const llvm::SmallVectorImpl &isDevicePtrSyms) { llvm::SmallVector concatSyms; concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() + - mappedSyms.size()); + mappedSyms.size() + isDevicePtrSyms.size()); concatSyms.append(privatizedSyms.begin(), privatizedSyms.end()); concatSyms.append(hasDevSyms.begin(), hasDevSyms.end()); concatSyms.append(mappedSyms.begin(), mappedSyms.end()); + concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end()); auto checkSymbol = [&](const semantics::Symbol &checkSym) { return std::any_of(concatSyms.begin(), concatSyms.end(), @@ -2534,6 +2536,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, loc, clauseOps, defaultMaps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); + if (!isDevicePtrSyms.empty()) { + // is_device_ptr maps get duplicated so the clause and synthesized + // has_device_addr entry each own a unique MapInfoOp user, keeping + // MapInfoFinalization happy while still wiring the symbol into + // has_device_addr when the user didn’t spell it explicitly. + auto insertionPt = firOpBuilder.saveInsertionPoint(); + auto alreadyPresent = [&](const semantics::Symbol *sym) { + return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) { + return s && sym && s->GetUltimate() == sym->GetUltimate(); + }); + }; + + for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) { + mlir::Value mapVal = clauseOps.isDevicePtrVars[idx]; + assert(sym && "expected symbol for is_device_ptr"); + assert(mapVal && "expected map value for is_device_ptr"); + auto mapInfo = mapVal.getDefiningOp(); + assert(mapInfo && "expected map info op"); + + if (!alreadyPresent(sym)) { + clauseOps.hasDeviceAddrVars.push_back(mapVal); + hasDeviceAddrSyms.push_back(sym); + } + + firOpBuilder.setInsertionPointAfter(mapInfo); + mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation()); + auto clonedMapInfo = mlir::cast(clonedOp); + clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult(); + } + firOpBuilder.restoreInsertionPoint(insertionPt); + } + DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ lower::omp::isLastItemInQueue(item, queue), @@ -2573,7 +2607,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return; if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(), - hasDeviceAddrSyms, mapSyms)) { + hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) { if (const auto *details = sym.template detailsIf()) converter.copySymbolBinding(details->symbol(), sym); diff --git a/flang/test/Integration/OpenMP/map-types-and-sizes.f90 b/flang/test/Integration/OpenMP/map-types-and-sizes.f90 index 8eb40c089e05c..d6d93985d9895 100644 --- a/flang/test/Integration/OpenMP/map-types-and-sizes.f90 +++ b/flang/test/Integration/OpenMP/map-types-and-sizes.f90 @@ -33,6 +33,15 @@ subroutine mapType_array !$omp end target end subroutine mapType_array +!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8] +!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 33] +subroutine mapType_is_device_ptr + use iso_c_binding, only : c_ptr + type(c_ptr) :: p + !$omp target is_device_ptr(p) + !$omp end target +end subroutine mapType_is_device_ptr + !CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [5 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0] !CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [5 x i64] [i64 32, i64 281474976711173, i64 281474976711173, i64 281474976711171, i64 281474976711187] subroutine mapType_ptr diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90 index d664dbefe1997..c5d39695e5389 100644 --- a/flang/test/Lower/OpenMP/target.f90 +++ b/flang/test/Lower/OpenMP/target.f90 @@ -566,6 +566,36 @@ subroutine omp_target_device_addr end subroutine omp_target_device_addr +!=============================================================================== +! Target `is_device_ptr` clause +!=============================================================================== + +!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() { +subroutine omp_target_is_device_ptr + use iso_c_binding, only: c_ptr + implicit none + integer :: i + integer :: arr(4) + type(c_ptr) :: p + + i = 0 + arr = 0 + + !CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"} + !CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"} + !CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"} + !CHECK: omp.target is_device_ptr(%[[P_IS]] : + !CHECK-SAME: has_device_addr(%[[P_STORAGE]] -> + !CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] -> + !$omp target is_device_ptr(p) + i = i + 1 + arr(1) = i + !$omp end target + !CHECK: omp.terminator + !CHECK: } +end subroutine omp_target_is_device_ptr + + !=============================================================================== ! Target Data with unstructured code !=============================================================================== diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index dfdc0b2803763..ea5489faaf4fc 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -128,6 +128,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>; def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>; def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>; def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>; +def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>; def ClauseMapFlags : OpenMP_BitEnumAttr< "ClauseMapFlags", @@ -151,7 +152,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr< ClauseMapFlagsAttachAuto, ClauseMapFlagsRefPtr, ClauseMapFlagsRefPtee, - ClauseMapFlagsRefPtrPtee + ClauseMapFlagsRefPtrPtee, + ClauseMapFlagsIsDevicePtr ]>; def ClauseMapFlagsAttr : OpenMP_EnumAttr !llvm.ptr {name = ""} + omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) { + omp.terminator + } + llvm.return + } +} + +// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8] +// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288] +// CHECK-LABEL: define void @_QPomp_target_is_device_ptr diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 731a6322736d4..231cd7aa784fe 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -225,17 +225,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) { // ----- -llvm.func @target_is_device_ptr(%x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.target}} - omp.target is_device_ptr(%x : !llvm.ptr) { - omp.terminator - } - llvm.return -} - -// ----- - llvm.func @target_enter_data_depend(%x: !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}} // expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}} diff --git a/offload/test/offloading/fortran/target-is-device-ptr.f90 b/offload/test/offloading/fortran/target-is-device-ptr.f90 new file mode 100644 index 0000000000000..d6d8c02f50d6a --- /dev/null +++ b/offload/test/offloading/fortran/target-is-device-ptr.f90 @@ -0,0 +1,60 @@ +! Validate that a device pointer obtained via omp_get_mapped_ptr can be used +! inside a TARGET region with the is_device_ptr clause. +! REQUIRES: flang, amdgcn-amd-amdhsa + +! RUN: %libomptarget-compile-fortran-run-and-check-generic + +module mod + implicit none + integer, parameter :: n = 4 +contains + subroutine kernel(dptr) + use iso_c_binding, only : c_ptr, c_f_pointer + implicit none + + type(c_ptr) :: dptr + integer, dimension(:), pointer :: b + integer :: i + + b => null() + + !$omp target is_device_ptr(dptr) + call c_f_pointer(dptr, b, [n]) + do i = 1, n + b(i) = b(i) + 1 + end do + !$omp end target + end subroutine kernel +end module mod + +program is_device_ptr_target + use iso_c_binding, only : c_ptr, c_loc, c_f_pointer + use omp_lib, only: omp_get_default_device, omp_get_mapped_ptr + use mod, only: kernel, n + implicit none + + integer, dimension(n), target :: a + integer :: dev + type(c_ptr) :: dptr + + a = [2, 4, 6, 8] + print '("BEFORE:", I3)', a + + dev = omp_get_default_device() + + !$omp target data map(tofrom: a) + dptr = omp_get_mapped_ptr(c_loc(a), dev) + call kernel(dptr) + !$omp end target data + + print '("AFTER: ", I3)', a + + if (all(a == [3, 5, 7, 9])) then + print '("PASS")' + else + print '("FAIL ", I3)', a + end if + +end program is_device_ptr_target + +!CHECK: PASS