Skip to content

Commit f5d82d2

Browse files
committed
[Flang][OpenMP] Add lowering support for is_device_ptr clause
Add support for OpenMP is_device_ptr clause for target directives.
1 parent 0621fd0 commit f5d82d2

File tree

6 files changed

+95
-33
lines changed

6 files changed

+95
-33
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
4242
return mlir::omp::ReductionModifier::defaultmod;
4343
}
4444

45-
/// Check for unsupported map operand types.
46-
static void checkMapType(mlir::Location location, mlir::Type type) {
47-
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
48-
type = refType.getElementType();
49-
if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
50-
if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
51-
TODO(location, "OMPD_target_data MapOperand BoxType");
52-
}
53-
5445
static mlir::omp::ScheduleModifier
5546
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
5647
switch (m) {
@@ -209,18 +200,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
209200
ifVal);
210201
}
211202

212-
static void addUseDeviceClause(
213-
lower::AbstractConverter &converter, const omp::ObjectList &objects,
214-
llvm::SmallVectorImpl<mlir::Value> &operands,
215-
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
216-
genObjectList(objects, converter, operands);
217-
for (mlir::Value &operand : operands)
218-
checkMapType(operand.getLoc(), operand.getType());
219-
220-
for (const omp::Object &object : objects)
221-
useDeviceSyms.push_back(object.sym());
222-
}
223-
224203
//===----------------------------------------------------------------------===//
225204
// ClauseProcessor unique clauses
226205
//===----------------------------------------------------------------------===//
@@ -1159,14 +1138,23 @@ bool ClauseProcessor::processInReduction(
11591138
}
11601139

11611140
bool ClauseProcessor::processIsDevicePtr(
1162-
mlir::omp::IsDevicePtrClauseOps &result,
1141+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
11631142
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
1164-
return findRepeatableClause<omp::clause::IsDevicePtr>(
1165-
[&](const omp::clause::IsDevicePtr &devPtrClause,
1166-
const parser::CharBlock &) {
1167-
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
1168-
isDeviceSyms);
1143+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1144+
bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
1145+
[&](const omp::clause::IsDevicePtr &clause,
1146+
const parser::CharBlock &source) {
1147+
mlir::Location location = converter.genLocation(source);
1148+
mlir::omp::ClauseMapFlags mapTypeBits =
1149+
mlir::omp::ClauseMapFlags::is_device_ptr;
1150+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1151+
parentMemberIndices, result.isDevicePtrVars,
1152+
isDeviceSyms);
11691153
});
1154+
1155+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1156+
result.isDevicePtrVars, isDeviceSyms);
1157+
return clauseFound;
11701158
}
11711159

11721160
bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class ClauseProcessor {
130130
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
131131
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
132132
bool processIsDevicePtr(
133-
mlir::omp::IsDevicePtrClauseOps &result,
133+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
134134
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
135135
bool processLinear(mlir::omp::LinearClauseOps &result) const;
136136
bool

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,7 +1673,7 @@ static void genTargetClauses(
16731673
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
16741674
}
16751675
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
1676-
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
1676+
cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms);
16771677
cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
16781678
&mapSyms);
16791679
cp.processNowait(clauseOps);
@@ -2485,13 +2485,15 @@ static bool isDuplicateMappedSymbol(
24852485
const semantics::Symbol &sym,
24862486
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
24872487
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2488-
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2488+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms,
2489+
const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) {
24892490
llvm::SmallVector<const semantics::Symbol *> concatSyms;
24902491
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2491-
mappedSyms.size());
2492+
mappedSyms.size() + isDevicePtrSyms.size());
24922493
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
24932494
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
24942495
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2496+
concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end());
24952497

24962498
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
24972499
return std::any_of(concatSyms.begin(), concatSyms.end(),
@@ -2531,6 +2533,41 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25312533
loc, clauseOps, defaultMaps, hasDeviceAddrSyms,
25322534
isDevicePtrSyms, mapSyms);
25332535

2536+
if (!isDevicePtrSyms.empty()) {
2537+
// is_device_ptr maps get duplicated so the clause and synthesized
2538+
// has_device_addr entry each own a unique MapInfoOp user, keeping
2539+
// MapInfoFinalization happy while still wiring the symbol into
2540+
// has_device_addr when the user didn’t spell it explicitly.
2541+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2542+
auto insertionPt = builder.saveInsertionPoint();
2543+
auto alreadyPresent = [&](const semantics::Symbol *sym) {
2544+
return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) {
2545+
return s && sym && s->GetUltimate() == sym->GetUltimate();
2546+
});
2547+
};
2548+
2549+
for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) {
2550+
mlir::Value mapVal = clauseOps.isDevicePtrVars[idx];
2551+
if (!sym || !mapVal)
2552+
continue;
2553+
auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>();
2554+
if (!mapInfo)
2555+
continue;
2556+
2557+
if (!alreadyPresent(sym)) {
2558+
clauseOps.hasDeviceAddrVars.push_back(mapVal);
2559+
hasDeviceAddrSyms.push_back(sym);
2560+
}
2561+
2562+
builder.setInsertionPointAfter(mapInfo);
2563+
auto clonedOp = builder.clone(*mapInfo.getOperation());
2564+
auto clonedMapInfo = mlir::dyn_cast<mlir::omp::MapInfoOp>(clonedOp);
2565+
assert(clonedMapInfo && "expected cloned map info op");
2566+
clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult();
2567+
}
2568+
builder.restoreInsertionPoint(insertionPt);
2569+
}
2570+
25342571
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
25352572
/*shouldCollectPreDeterminedSymbols=*/
25362573
lower::omp::isLastItemInQueue(item, queue),
@@ -2570,7 +2607,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25702607
return;
25712608

25722609
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2573-
hasDeviceAddrSyms, mapSyms)) {
2610+
hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) {
25742611
if (const auto *details =
25752612
sym.template detailsIf<semantics::HostAssocDetails>())
25762613
converter.copySymbolBinding(details->symbol(), sym);

flang/test/Lower/OpenMP/target.f90

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,36 @@ subroutine omp_target_device_addr
566566
end subroutine omp_target_device_addr
567567

568568

569+
!===============================================================================
570+
! Target `is_device_ptr` clause
571+
!===============================================================================
572+
573+
!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
574+
subroutine omp_target_is_device_ptr
575+
use iso_c_binding, only: c_associated, c_ptr
576+
implicit none
577+
integer :: i
578+
integer :: arr(4)
579+
type(c_ptr) :: p
580+
581+
i = 0
582+
arr = 0
583+
584+
!CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"}
585+
!CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"}
586+
!CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"}
587+
!CHECK: omp.target is_device_ptr(%[[P_IS]] :
588+
!CHECK-SAME: has_device_addr(%[[P_STORAGE]] ->
589+
!CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] ->
590+
!$omp target is_device_ptr(p)
591+
if (c_associated(p)) i = i + 1
592+
arr(1) = i
593+
!$omp end target
594+
!CHECK: omp.terminator
595+
!CHECK: }
596+
end subroutine omp_target_is_device_ptr
597+
598+
569599
!===============================================================================
570600
! Target Data with unstructured code
571601
!===============================================================================

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>;
126126
def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>;
127127
def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>;
128128
def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>;
129+
def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>;
129130

130131
def ClauseMapFlags : OpenMP_BitEnumAttr<
131132
"ClauseMapFlags",
@@ -149,7 +150,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr<
149150
ClauseMapFlagsAttachAuto,
150151
ClauseMapFlagsRefPtr,
151152
ClauseMapFlagsRefPtee,
152-
ClauseMapFlagsRefPtrPtee
153+
ClauseMapFlagsRefPtrPtee,
154+
ClauseMapFlagsIsDevicePtr
153155
]>;
154156

155157
def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
18171817
if (mapTypeMod == "ref_ptr_ptee")
18181818
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
18191819

1820+
if (mapTypeMod == "is_device_ptr")
1821+
mapTypeBits |= ClauseMapFlags::is_device_ptr;
1822+
18201823
return success();
18211824
};
18221825

@@ -1886,6 +1889,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
18861889
mapTypeStrs.push_back("ref_ptee");
18871890
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
18881891
mapTypeStrs.push_back("ref_ptr_ptee");
1892+
if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
1893+
mapTypeStrs.push_back("is_device_ptr");
18891894
if (mapFlags == ClauseMapFlags::none)
18901895
mapTypeStrs.push_back("none");
18911896

0 commit comments

Comments
 (0)