Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 15 additions & 27 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
return mlir::omp::ReductionModifier::defaultmod;
}

/// Check for unsupported map operand types.
static void checkMapType(mlir::Location location, mlir::Type type) {
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
type = refType.getElementType();
if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
TODO(location, "OMPD_target_data MapOperand BoxType");
}

static mlir::omp::ScheduleModifier
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
switch (m) {
Expand Down Expand Up @@ -209,18 +200,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
ifVal);
}

static void addUseDeviceClause(
lower::AbstractConverter &converter, const omp::ObjectList &objects,
llvm::SmallVectorImpl<mlir::Value> &operands,
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
genObjectList(objects, converter, operands);
for (mlir::Value &operand : operands)
checkMapType(operand.getLoc(), operand.getType());

for (const omp::Object &object : objects)
useDeviceSyms.push_back(object.sym());
}

//===----------------------------------------------------------------------===//
// ClauseProcessor unique clauses
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1159,14 +1138,23 @@ bool ClauseProcessor::processInReduction(
}

bool ClauseProcessor::processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
return findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &devPtrClause,
const parser::CharBlock &) {
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
isDeviceSyms);
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
[&](const omp::clause::IsDevicePtr &clause,
const parser::CharBlock &source) {
mlir::Location location = converter.genLocation(source);
mlir::omp::ClauseMapFlags mapTypeBits =
mlir::omp::ClauseMapFlags::is_device_ptr;
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
parentMemberIndices, result.isDevicePtrVars,
isDeviceSyms);
});

insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
result.isDevicePtrVars, isDeviceSyms);
return clauseFound;
}

bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class ClauseProcessor {
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
Expand Down
43 changes: 39 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1673,7 +1673,7 @@ static void genTargetClauses(
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
}
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms);
cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
&mapSyms);
cp.processNowait(clauseOps);
Expand Down Expand Up @@ -2485,13 +2485,15 @@ static bool isDuplicateMappedSymbol(
const semantics::Symbol &sym,
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms,
const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) {
llvm::SmallVector<const semantics::Symbol *> concatSyms;
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
mappedSyms.size());
mappedSyms.size() + isDevicePtrSyms.size());
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end());

auto checkSymbol = [&](const semantics::Symbol &checkSym) {
return std::any_of(concatSyms.begin(), concatSyms.end(),
Expand Down Expand Up @@ -2531,6 +2533,39 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
loc, clauseOps, defaultMaps, hasDeviceAddrSyms,
isDevicePtrSyms, mapSyms);

if (!isDevicePtrSyms.empty()) {
// is_device_ptr maps get duplicated so the clause and synthesized
// has_device_addr entry each own a unique MapInfoOp user, keeping
// MapInfoFinalization happy while still wiring the symbol into
// has_device_addr when the user didn’t spell it explicitly.
auto insertionPt = firOpBuilder.saveInsertionPoint();
auto alreadyPresent = [&](const semantics::Symbol *sym) {
return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) {
return s && sym && s->GetUltimate() == sym->GetUltimate();
});
};

for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) {
mlir::Value mapVal = clauseOps.isDevicePtrVars[idx];
assert(sym && "expected symbol for is_device_ptr");
assert(mapVal && "expected map value for is_device_ptr");
auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>();
assert(mapInfo && "expected map info op");

if (!alreadyPresent(sym)) {
clauseOps.hasDeviceAddrVars.push_back(mapVal);
hasDeviceAddrSyms.push_back(sym);
}

firOpBuilder.setInsertionPointAfter(mapInfo);
auto clonedOp = firOpBuilder.clone(*mapInfo.getOperation());
auto clonedMapInfo = mlir::dyn_cast<mlir::omp::MapInfoOp>(clonedOp);
assert(clonedMapInfo && "expected cloned map info op");
clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult();
}
firOpBuilder.restoreInsertionPoint(insertionPt);
}

DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/
lower::omp::isLastItemInQueue(item, queue),
Expand Down Expand Up @@ -2570,7 +2605,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return;

if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
hasDeviceAddrSyms, mapSyms)) {
hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) {
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
Expand Down
30 changes: 30 additions & 0 deletions flang/test/Lower/OpenMP/target.f90
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,36 @@ subroutine omp_target_device_addr
end subroutine omp_target_device_addr


!===============================================================================
! Target `is_device_ptr` clause
!===============================================================================

!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
subroutine omp_target_is_device_ptr
use iso_c_binding, only: c_ptr
implicit none
integer :: i
integer :: arr(4)
type(c_ptr) :: p

i = 0
arr = 0

!CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"}
!CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"}
!CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"}
!CHECK: omp.target is_device_ptr(%[[P_IS]] :
!CHECK-SAME: has_device_addr(%[[P_STORAGE]] ->
!CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] ->
!$omp target is_device_ptr(p)
i = i + 1
arr(1) = i
!$omp end target
!CHECK: omp.terminator
!CHECK: }
end subroutine omp_target_is_device_ptr


!===============================================================================
! Target Data with unstructured code
!===============================================================================
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>;
def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>;
def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>;
def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>;
def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>;

def ClauseMapFlags : OpenMP_BitEnumAttr<
"ClauseMapFlags",
Expand All @@ -149,7 +150,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr<
ClauseMapFlagsAttachAuto,
ClauseMapFlagsRefPtr,
ClauseMapFlagsRefPtee,
ClauseMapFlagsRefPtrPtee
ClauseMapFlagsRefPtrPtee,
ClauseMapFlagsIsDevicePtr
]>;

def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags,
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
if (mapTypeMod == "ref_ptr_ptee")
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;

if (mapTypeMod == "is_device_ptr")
mapTypeBits |= ClauseMapFlags::is_device_ptr;

return success();
};

Expand Down Expand Up @@ -1886,6 +1889,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
mapTypeStrs.push_back("ref_ptee");
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
mapTypeStrs.push_back("ref_ptr_ptee");
if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
mapTypeStrs.push_back("is_device_ptr");
if (mapFlags == ClauseMapFlags::none)
mapTypeStrs.push_back("none");

Expand Down