Skip to content

Commit a77c494

Browse files
authored
[Flang][OpenMP] Add lowering support for is_device_ptr clause (#169331)
Add support for the OpenMP is_device_ptr clause for target directives. [MLIR][OpenMP] Add OpenMPToLLVMIRTranslation support for is_device_ptr (#169367): This PR adds support for the OpenMP is_device_ptr clause in the MLIR-to-LLVM-IR translation for target regions. The is_device_ptr clause allows device pointers (allocated via OpenMP runtime APIs) to be used directly in target regions without implicit mapping.
1 parent 68caa8b commit a77c494

File tree

11 files changed

+199
-51
lines changed

11 files changed

+199
-51
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
4444
return mlir::omp::ReductionModifier::defaultmod;
4545
}
4646

47-
/// Check for unsupported map operand types.
48-
static void checkMapType(mlir::Location location, mlir::Type type) {
49-
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
50-
type = refType.getElementType();
51-
if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
52-
if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
53-
TODO(location, "OMPD_target_data MapOperand BoxType");
54-
}
55-
5647
static mlir::omp::ScheduleModifier
5748
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
5849
switch (m) {
@@ -211,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
211202
ifVal);
212203
}
213204

214-
static void addUseDeviceClause(
215-
lower::AbstractConverter &converter, const omp::ObjectList &objects,
216-
llvm::SmallVectorImpl<mlir::Value> &operands,
217-
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
218-
genObjectList(objects, converter, operands);
219-
for (mlir::Value &operand : operands)
220-
checkMapType(operand.getLoc(), operand.getType());
221-
222-
for (const omp::Object &object : objects)
223-
useDeviceSyms.push_back(object.sym());
224-
}
225-
226205
//===----------------------------------------------------------------------===//
227206
// ClauseProcessor unique clauses
228207
//===----------------------------------------------------------------------===//
@@ -1225,14 +1204,26 @@ bool ClauseProcessor::processInReduction(
12251204
}
12261205

12271206
bool ClauseProcessor::processIsDevicePtr(
1228-
mlir::omp::IsDevicePtrClauseOps &result,
1207+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
12291208
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
1230-
return findRepeatableClause<omp::clause::IsDevicePtr>(
1231-
[&](const omp::clause::IsDevicePtr &devPtrClause,
1232-
const parser::CharBlock &) {
1233-
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
1234-
isDeviceSyms);
1209+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1210+
bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
1211+
[&](const omp::clause::IsDevicePtr &clause,
1212+
const parser::CharBlock &source) {
1213+
mlir::Location location = converter.genLocation(source);
1214+
// Force a map so the descriptor is materialized on the device with the
1215+
// device address inside.
1216+
mlir::omp::ClauseMapFlags mapTypeBits =
1217+
mlir::omp::ClauseMapFlags::is_device_ptr |
1218+
mlir::omp::ClauseMapFlags::to;
1219+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1220+
parentMemberIndices, result.isDevicePtrVars,
1221+
isDeviceSyms);
12351222
});
1223+
1224+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1225+
result.isDevicePtrVars, isDeviceSyms);
1226+
return clauseFound;
12361227
}
12371228

12381229
bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class ClauseProcessor {
134134
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
135135
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
136136
bool processIsDevicePtr(
137-
mlir::omp::IsDevicePtrClauseOps &result,
137+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
138138
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
139139
bool processLinear(mlir::omp::LinearClauseOps &result) const;
140140
bool

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1672,7 +1672,7 @@ static void genTargetClauses(
16721672
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
16731673
}
16741674
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
1675-
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
1675+
cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms);
16761676
cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
16771677
&mapSyms);
16781678
cp.processNowait(clauseOps);
@@ -2488,13 +2488,15 @@ static bool isDuplicateMappedSymbol(
24882488
const semantics::Symbol &sym,
24892489
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
24902490
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2491-
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2491+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms,
2492+
const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) {
24922493
llvm::SmallVector<const semantics::Symbol *> concatSyms;
24932494
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2494-
mappedSyms.size());
2495+
mappedSyms.size() + isDevicePtrSyms.size());
24952496
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
24962497
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
24972498
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2499+
concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end());
24982500

24992501
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
25002502
return std::any_of(concatSyms.begin(), concatSyms.end(),
@@ -2534,6 +2536,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25342536
loc, clauseOps, defaultMaps, hasDeviceAddrSyms,
25352537
isDevicePtrSyms, mapSyms);
25362538

2539+
if (!isDevicePtrSyms.empty()) {
2540+
// is_device_ptr maps get duplicated so the clause and synthesized
2541+
// has_device_addr entry each own a unique MapInfoOp user, keeping
2542+
// MapInfoFinalization happy while still wiring the symbol into
2543+
// has_device_addr when the user didn’t spell it explicitly.
2544+
auto insertionPt = firOpBuilder.saveInsertionPoint();
2545+
auto alreadyPresent = [&](const semantics::Symbol *sym) {
2546+
return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) {
2547+
return s && sym && s->GetUltimate() == sym->GetUltimate();
2548+
});
2549+
};
2550+
2551+
for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) {
2552+
mlir::Value mapVal = clauseOps.isDevicePtrVars[idx];
2553+
assert(sym && "expected symbol for is_device_ptr");
2554+
assert(mapVal && "expected map value for is_device_ptr");
2555+
auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>();
2556+
assert(mapInfo && "expected map info op");
2557+
2558+
if (!alreadyPresent(sym)) {
2559+
clauseOps.hasDeviceAddrVars.push_back(mapVal);
2560+
hasDeviceAddrSyms.push_back(sym);
2561+
}
2562+
2563+
firOpBuilder.setInsertionPointAfter(mapInfo);
2564+
mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation());
2565+
auto clonedMapInfo = mlir::cast<mlir::omp::MapInfoOp>(clonedOp);
2566+
clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult();
2567+
}
2568+
firOpBuilder.restoreInsertionPoint(insertionPt);
2569+
}
2570+
25372571
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
25382572
/*shouldCollectPreDeterminedSymbols=*/
25392573
lower::omp::isLastItemInQueue(item, queue),
@@ -2573,7 +2607,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25732607
return;
25742608

25752609
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2576-
hasDeviceAddrSyms, mapSyms)) {
2610+
hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) {
25772611
if (const auto *details =
25782612
sym.template detailsIf<semantics::HostAssocDetails>())
25792613
converter.copySymbolBinding(details->symbol(), sym);

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,15 @@ subroutine mapType_array
3333
!$omp end target
3434
end subroutine mapType_array
3535

36+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
37+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 33]
38+
subroutine mapType_is_device_ptr
39+
use iso_c_binding, only : c_ptr
40+
type(c_ptr) :: p
41+
!$omp target is_device_ptr(p)
42+
!$omp end target
43+
end subroutine mapType_is_device_ptr
44+
3645
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [5 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0]
3746
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [5 x i64] [i64 32, i64 281474976711173, i64 281474976711173, i64 281474976711171, i64 281474976711187]
3847
subroutine mapType_ptr

flang/test/Lower/OpenMP/target.f90

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,36 @@ subroutine omp_target_device_addr
566566
end subroutine omp_target_device_addr
567567

568568

569+
!===============================================================================
570+
! Target `is_device_ptr` clause
571+
!===============================================================================
572+
573+
!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
574+
subroutine omp_target_is_device_ptr
575+
use iso_c_binding, only: c_ptr
576+
implicit none
577+
integer :: i
578+
integer :: arr(4)
579+
type(c_ptr) :: p
580+
581+
i = 0
582+
arr = 0
583+
584+
!CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"}
585+
!CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"}
586+
!CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"}
587+
!CHECK: omp.target is_device_ptr(%[[P_IS]] :
588+
!CHECK-SAME: has_device_addr(%[[P_STORAGE]] ->
589+
!CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] ->
590+
!$omp target is_device_ptr(p)
591+
i = i + 1
592+
arr(1) = i
593+
!$omp end target
594+
!CHECK: omp.terminator
595+
!CHECK: }
596+
end subroutine omp_target_is_device_ptr
597+
598+
569599
!===============================================================================
570600
! Target Data with unstructured code
571601
!===============================================================================

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>;
128128
def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>;
129129
def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>;
130130
def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>;
131+
def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>;
131132

132133
def ClauseMapFlags : OpenMP_BitEnumAttr<
133134
"ClauseMapFlags",
@@ -151,7 +152,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr<
151152
ClauseMapFlagsAttachAuto,
152153
ClauseMapFlagsRefPtr,
153154
ClauseMapFlagsRefPtee,
154-
ClauseMapFlagsRefPtrPtee
155+
ClauseMapFlagsRefPtrPtee,
156+
ClauseMapFlagsIsDevicePtr
155157
]>;
156158

157159
def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
18181818
if (mapTypeMod == "ref_ptr_ptee")
18191819
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
18201820

1821+
if (mapTypeMod == "is_device_ptr")
1822+
mapTypeBits |= ClauseMapFlags::is_device_ptr;
1823+
18211824
return success();
18221825
};
18231826

@@ -1887,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
18871890
mapTypeStrs.push_back("ref_ptee");
18881891
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
18891892
mapTypeStrs.push_back("ref_ptr_ptee");
1893+
if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
1894+
mapTypeStrs.push_back("is_device_ptr");
18901895
if (mapFlags == ClauseMapFlags::none)
18911896
mapTypeStrs.push_back("none");
18921897

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
330330
op.getInReductionSyms())
331331
result = todo("in_reduction");
332332
};
333-
auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
334-
if (!op.getIsDevicePtrVars().empty())
335-
result = todo("is_device_ptr");
336-
};
337333
auto checkLinear = [&todo](auto op, LogicalResult &result) {
338334
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
339335
result = todo("linear");
@@ -441,7 +437,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
441437
checkBare(op, result);
442438
checkDevice(op, result);
443439
checkInReduction(op, result);
444-
checkIsDevicePtr(op, result);
445440
})
446441
.Default([](Operation &) {
447442
// Assume all clauses for an operation can be translated unless they are
@@ -3959,6 +3954,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
39593954
auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
39603955
return (mlirFlags & flag) == flag;
39613956
};
3957+
const bool hasExplicitMap =
3958+
(mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
3959+
omp::ClauseMapFlags::none;
39623960

39633961
llvm::omp::OpenMPOffloadMappingFlags mapType =
39643962
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
@@ -3999,6 +3997,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
39993997
if (mapTypeToBool(omp::ClauseMapFlags::attach))
40003998
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
40013999

4000+
if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4001+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4002+
if (!hasExplicitMap)
4003+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4004+
}
4005+
40024006
return mapType;
40034007
}
40044008

@@ -4122,6 +4126,9 @@ static void collectMapDataFromMapOperands(
41224126
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
41234127
auto mapType = convertClauseMapFlags(mapOp.getMapType());
41244128
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4129+
bool isDevicePtr =
4130+
(mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4131+
omp::ClauseMapFlags::none;
41254132

41264133
mapData.OriginalValue.push_back(origValue);
41274134
mapData.BasePointers.push_back(origValue);
@@ -4148,14 +4155,18 @@ static void collectMapDataFromMapOperands(
41484155
mapData.Mappers.push_back(nullptr);
41494156
}
41504157
} else {
4158+
// For is_device_ptr we need the map type to propagate so the runtime
4159+
// can materialize the device-side copy of the pointer container.
41514160
mapData.Types.push_back(
4152-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4161+
isDevicePtr ? mapType
4162+
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
41534163
mapData.Mappers.push_back(nullptr);
41544164
}
41554165
mapData.Names.push_back(LLVM::createMappingInformation(
41564166
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
41574167
mapData.DevicePointers.push_back(
4158-
llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4168+
isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4169+
: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
41594170
mapData.IsAMapping.push_back(false);
41604171
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
41614172
}

mlir/test/Target/LLVMIR/omptarget-llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
622622
// CHECK: br label %[[VAL_40]]
623623
// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
624624
// CHECK: ret void
625+
626+
// -----
627+
628+
module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
629+
llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) {
630+
%map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)
631+
map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""}
632+
omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) {
633+
omp.terminator
634+
}
635+
llvm.return
636+
}
637+
}
638+
639+
// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8]
640+
// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288]
641+
// CHECK-LABEL: define void @_QPomp_target_is_device_ptr

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -225,17 +225,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
225225

226226
// -----
227227

228-
llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
229-
// expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
230-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
231-
omp.target is_device_ptr(%x : !llvm.ptr) {
232-
omp.terminator
233-
}
234-
llvm.return
235-
}
236-
237-
// -----
238-
239228
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
240229
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
241230
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}

0 commit comments

Comments
 (0)