Skip to content

Commit b35269c

Browse files
TIFitis, Honey Goyal
authored and committed
Reland "[Flang][OpenMP] Add lowering support for is_device_ptr clause (llvm#169331)" (llvm#170851)
Add support for the OpenMP is_device_ptr clause for target directives. [MLIR][OpenMP] Add OpenMPToLLVMIRTranslation support for is_device_ptr (llvm#169367). This PR adds support for the OpenMP is_device_ptr clause in the MLIR to LLVM IR translation for target regions. The is_device_ptr clause allows device pointers (allocated via OpenMP runtime APIs) to be used directly in target regions without implicit mapping.
1 parent 36637fd commit b35269c

File tree

11 files changed

+199
-51
lines changed

11 files changed

+199
-51
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 18 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -44,15 +44,6 @@ mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
4444
return mlir::omp::ReductionModifier::defaultmod;
4545
}
4646

47-
/// Check for unsupported map operand types.
48-
static void checkMapType(mlir::Location location, mlir::Type type) {
49-
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
50-
type = refType.getElementType();
51-
if (auto boxType = mlir::dyn_cast_or_null<fir::BoxType>(type))
52-
if (!mlir::isa<fir::PointerType>(boxType.getElementType()))
53-
TODO(location, "OMPD_target_data MapOperand BoxType");
54-
}
55-
5647
static mlir::omp::ScheduleModifier
5748
translateScheduleModifier(const omp::clause::Schedule::OrderingModifier &m) {
5849
switch (m) {
@@ -211,18 +202,6 @@ getIfClauseOperand(lower::AbstractConverter &converter,
211202
ifVal);
212203
}
213204

214-
static void addUseDeviceClause(
215-
lower::AbstractConverter &converter, const omp::ObjectList &objects,
216-
llvm::SmallVectorImpl<mlir::Value> &operands,
217-
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
218-
genObjectList(objects, converter, operands);
219-
for (mlir::Value &operand : operands)
220-
checkMapType(operand.getLoc(), operand.getType());
221-
222-
for (const omp::Object &object : objects)
223-
useDeviceSyms.push_back(object.sym());
224-
}
225-
226205
//===----------------------------------------------------------------------===//
227206
// ClauseProcessor unique clauses
228207
//===----------------------------------------------------------------------===//
@@ -1225,14 +1204,26 @@ bool ClauseProcessor::processInReduction(
12251204
}
12261205

12271206
bool ClauseProcessor::processIsDevicePtr(
1228-
mlir::omp::IsDevicePtrClauseOps &result,
1207+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
12291208
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
1230-
return findRepeatableClause<omp::clause::IsDevicePtr>(
1231-
[&](const omp::clause::IsDevicePtr &devPtrClause,
1232-
const parser::CharBlock &) {
1233-
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
1234-
isDeviceSyms);
1209+
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
1210+
bool clauseFound = findRepeatableClause<omp::clause::IsDevicePtr>(
1211+
[&](const omp::clause::IsDevicePtr &clause,
1212+
const parser::CharBlock &source) {
1213+
mlir::Location location = converter.genLocation(source);
1214+
// Force a map so the descriptor is materialized on the device with the
1215+
// device address inside.
1216+
mlir::omp::ClauseMapFlags mapTypeBits =
1217+
mlir::omp::ClauseMapFlags::is_device_ptr |
1218+
mlir::omp::ClauseMapFlags::to;
1219+
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
1220+
parentMemberIndices, result.isDevicePtrVars,
1221+
isDeviceSyms);
12351222
});
1223+
1224+
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1225+
result.isDevicePtrVars, isDeviceSyms);
1226+
return clauseFound;
12361227
}
12371228

12381229
bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,7 @@ class ClauseProcessor {
134134
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
135135
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
136136
bool processIsDevicePtr(
137-
mlir::omp::IsDevicePtrClauseOps &result,
137+
lower::StatementContext &stmtCtx, mlir::omp::IsDevicePtrClauseOps &result,
138138
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
139139
bool processLinear(mlir::omp::LinearClauseOps &result) const;
140140
bool

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1671,7 +1671,7 @@ static void genTargetClauses(
16711671
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
16721672
}
16731673
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
1674-
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
1674+
cp.processIsDevicePtr(stmtCtx, clauseOps, isDevicePtrSyms);
16751675
cp.processMap(loc, stmtCtx, clauseOps, llvm::omp::Directive::OMPD_unknown,
16761676
&mapSyms);
16771677
cp.processNowait(clauseOps);
@@ -2487,13 +2487,15 @@ static bool isDuplicateMappedSymbol(
24872487
const semantics::Symbol &sym,
24882488
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
24892489
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2490-
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2490+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms,
2491+
const llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms) {
24912492
llvm::SmallVector<const semantics::Symbol *> concatSyms;
24922493
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2493-
mappedSyms.size());
2494+
mappedSyms.size() + isDevicePtrSyms.size());
24942495
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
24952496
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
24962497
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2498+
concatSyms.append(isDevicePtrSyms.begin(), isDevicePtrSyms.end());
24972499

24982500
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
24992501
return std::any_of(concatSyms.begin(), concatSyms.end(),
@@ -2533,6 +2535,38 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25332535
loc, clauseOps, defaultMaps, hasDeviceAddrSyms,
25342536
isDevicePtrSyms, mapSyms);
25352537

2538+
if (!isDevicePtrSyms.empty()) {
2539+
// is_device_ptr maps get duplicated so the clause and synthesized
2540+
// has_device_addr entry each own a unique MapInfoOp user, keeping
2541+
// MapInfoFinalization happy while still wiring the symbol into
2542+
// has_device_addr when the user didn’t spell it explicitly.
2543+
auto insertionPt = firOpBuilder.saveInsertionPoint();
2544+
auto alreadyPresent = [&](const semantics::Symbol *sym) {
2545+
return llvm::any_of(hasDeviceAddrSyms, [&](const semantics::Symbol *s) {
2546+
return s && sym && s->GetUltimate() == sym->GetUltimate();
2547+
});
2548+
};
2549+
2550+
for (auto [idx, sym] : llvm::enumerate(isDevicePtrSyms)) {
2551+
mlir::Value mapVal = clauseOps.isDevicePtrVars[idx];
2552+
assert(sym && "expected symbol for is_device_ptr");
2553+
assert(mapVal && "expected map value for is_device_ptr");
2554+
auto mapInfo = mapVal.getDefiningOp<mlir::omp::MapInfoOp>();
2555+
assert(mapInfo && "expected map info op");
2556+
2557+
if (!alreadyPresent(sym)) {
2558+
clauseOps.hasDeviceAddrVars.push_back(mapVal);
2559+
hasDeviceAddrSyms.push_back(sym);
2560+
}
2561+
2562+
firOpBuilder.setInsertionPointAfter(mapInfo);
2563+
mlir::Operation *clonedOp = firOpBuilder.clone(*mapInfo.getOperation());
2564+
auto clonedMapInfo = mlir::cast<mlir::omp::MapInfoOp>(clonedOp);
2565+
clauseOps.isDevicePtrVars[idx] = clonedMapInfo.getResult();
2566+
}
2567+
firOpBuilder.restoreInsertionPoint(insertionPt);
2568+
}
2569+
25362570
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
25372571
/*shouldCollectPreDeterminedSymbols=*/
25382572
lower::omp::isLastItemInQueue(item, queue),
@@ -2572,7 +2606,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25722606
return;
25732607

25742608
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2575-
hasDeviceAddrSyms, mapSyms)) {
2609+
hasDeviceAddrSyms, mapSyms, isDevicePtrSyms)) {
25762610
if (const auto *details =
25772611
sym.template detailsIf<semantics::HostAssocDetails>())
25782612
converter.copySymbolBinding(details->symbol(), sym);

flang/test/Integration/OpenMP/map-types-and-sizes.f90

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,15 @@ subroutine mapType_array
3333
!$omp end target
3434
end subroutine mapType_array
3535

36+
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [1 x i64] [i64 8]
37+
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [1 x i64] [i64 33]
38+
subroutine mapType_is_device_ptr
39+
use iso_c_binding, only : c_ptr
40+
type(c_ptr) :: p
41+
!$omp target is_device_ptr(p)
42+
!$omp end target
43+
end subroutine mapType_is_device_ptr
44+
3645
!CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [5 x i64] [i64 0, i64 0, i64 0, i64 8, i64 0]
3746
!CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [5 x i64] [i64 32, i64 281474976711173, i64 281474976711173, i64 281474976711171, i64 281474976711187]
3847
subroutine mapType_ptr

flang/test/Lower/OpenMP/target.f90

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -566,6 +566,36 @@ subroutine omp_target_device_addr
566566
end subroutine omp_target_device_addr
567567

568568

569+
!===============================================================================
570+
! Target `is_device_ptr` clause
571+
!===============================================================================
572+
573+
!CHECK-LABEL: func.func @_QPomp_target_is_device_ptr() {
574+
subroutine omp_target_is_device_ptr
575+
use iso_c_binding, only: c_ptr
576+
implicit none
577+
integer :: i
578+
integer :: arr(4)
579+
type(c_ptr) :: p
580+
581+
i = 0
582+
arr = 0
583+
584+
!CHECK: %[[P_STORAGE:.*]] = omp.map.info {{.*}}{name = "p"}
585+
!CHECK: %[[P_IS:.*]] = omp.map.info {{.*}}{name = "p"}
586+
!CHECK: %[[ARR_MAP:.*]] = omp.map.info {{.*}}{name = "arr"}
587+
!CHECK: omp.target is_device_ptr(%[[P_IS]] :
588+
!CHECK-SAME: has_device_addr(%[[P_STORAGE]] ->
589+
!CHECK-SAME: map_entries({{.*}}%[[ARR_MAP]] ->
590+
!$omp target is_device_ptr(p)
591+
i = i + 1
592+
arr(1) = i
593+
!$omp end target
594+
!CHECK: omp.terminator
595+
!CHECK: }
596+
end subroutine omp_target_is_device_ptr
597+
598+
569599
!===============================================================================
570600
! Target Data with unstructured code
571601
!===============================================================================

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,6 +128,7 @@ def ClauseMapFlagsAttachAuto : I32BitEnumAttrCaseBit<"attach_auto", 15>;
128128
def ClauseMapFlagsRefPtr : I32BitEnumAttrCaseBit<"ref_ptr", 16>;
129129
def ClauseMapFlagsRefPtee : I32BitEnumAttrCaseBit<"ref_ptee", 17>;
130130
def ClauseMapFlagsRefPtrPtee : I32BitEnumAttrCaseBit<"ref_ptr_ptee", 18>;
131+
def ClauseMapFlagsIsDevicePtr : I32BitEnumAttrCaseBit<"is_device_ptr", 19>;
131132

132133
def ClauseMapFlags : OpenMP_BitEnumAttr<
133134
"ClauseMapFlags",
@@ -151,7 +152,8 @@ def ClauseMapFlags : OpenMP_BitEnumAttr<
151152
ClauseMapFlagsAttachAuto,
152153
ClauseMapFlagsRefPtr,
153154
ClauseMapFlagsRefPtee,
154-
ClauseMapFlagsRefPtrPtee
155+
ClauseMapFlagsRefPtrPtee,
156+
ClauseMapFlagsIsDevicePtr
155157
]>;
156158

157159
def ClauseMapFlagsAttr : OpenMP_EnumAttr<ClauseMapFlags,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1818,6 +1818,9 @@ static ParseResult parseMapClause(OpAsmParser &parser,
18181818
if (mapTypeMod == "ref_ptr_ptee")
18191819
mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
18201820

1821+
if (mapTypeMod == "is_device_ptr")
1822+
mapTypeBits |= ClauseMapFlags::is_device_ptr;
1823+
18211824
return success();
18221825
};
18231826

@@ -1887,6 +1890,8 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
18871890
mapTypeStrs.push_back("ref_ptee");
18881891
if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
18891892
mapTypeStrs.push_back("ref_ptr_ptee");
1893+
if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
1894+
mapTypeStrs.push_back("is_device_ptr");
18901895
if (mapFlags == ClauseMapFlags::none)
18911896
mapTypeStrs.push_back("none");
18921897

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -332,10 +332,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
332332
op.getInReductionSyms())
333333
result = todo("in_reduction");
334334
};
335-
auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
336-
if (!op.getIsDevicePtrVars().empty())
337-
result = todo("is_device_ptr");
338-
};
339335
auto checkNowait = [&todo](auto op, LogicalResult &result) {
340336
if (op.getNowait())
341337
result = todo("nowait");
@@ -435,7 +431,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
435431
checkBare(op, result);
436432
checkDevice(op, result);
437433
checkInReduction(op, result);
438-
checkIsDevicePtr(op, result);
439434
})
440435
.Default([](Operation &) {
441436
// Assume all clauses for an operation can be translated unless they are
@@ -3986,6 +3981,9 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
39863981
auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) {
39873982
return (mlirFlags & flag) == flag;
39883983
};
3984+
const bool hasExplicitMap =
3985+
(mlirFlags & ~omp::ClauseMapFlags::is_device_ptr) !=
3986+
omp::ClauseMapFlags::none;
39893987

39903988
llvm::omp::OpenMPOffloadMappingFlags mapType =
39913989
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
@@ -4026,6 +4024,12 @@ convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) {
40264024
if (mapTypeToBool(omp::ClauseMapFlags::attach))
40274025
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH;
40284026

4027+
if (mapTypeToBool(omp::ClauseMapFlags::is_device_ptr)) {
4028+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4029+
if (!hasExplicitMap)
4030+
mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4031+
}
4032+
40294033
return mapType;
40304034
}
40314035

@@ -4149,6 +4153,9 @@ static void collectMapDataFromMapOperands(
41494153
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
41504154
auto mapType = convertClauseMapFlags(mapOp.getMapType());
41514155
auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
4156+
bool isDevicePtr =
4157+
(mapOp.getMapType() & omp::ClauseMapFlags::is_device_ptr) !=
4158+
omp::ClauseMapFlags::none;
41524159

41534160
mapData.OriginalValue.push_back(origValue);
41544161
mapData.BasePointers.push_back(origValue);
@@ -4175,14 +4182,18 @@ static void collectMapDataFromMapOperands(
41754182
mapData.Mappers.push_back(nullptr);
41764183
}
41774184
} else {
4185+
// For is_device_ptr we need the map type to propagate so the runtime
4186+
// can materialize the device-side copy of the pointer container.
41784187
mapData.Types.push_back(
4179-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
4188+
isDevicePtr ? mapType
4189+
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
41804190
mapData.Mappers.push_back(nullptr);
41814191
}
41824192
mapData.Names.push_back(LLVM::createMappingInformation(
41834193
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
41844194
mapData.DevicePointers.push_back(
4185-
llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
4195+
isDevicePtr ? llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer
4196+
: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
41864197
mapData.IsAMapping.push_back(false);
41874198
mapData.IsAMember.push_back(checkIsAMember(hasDevAddrOperands, mapOp));
41884199
}

mlir/test/Target/LLVMIR/omptarget-llvm.mlir

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -622,3 +622,20 @@ module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
622622
// CHECK: br label %[[VAL_40]]
623623
// CHECK: omp.done: ; preds = %[[VAL_68]], %[[VAL_63]], %[[VAL_32]]
624624
// CHECK: ret void
625+
626+
// -----
627+
628+
module attributes {omp.target_triples = ["amdgcn-amd-amdhsa"]} {
629+
llvm.func @_QPomp_target_is_device_ptr(%arg0 : !llvm.ptr) {
630+
%map = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.ptr)
631+
map_clauses(is_device_ptr) capture(ByRef) -> !llvm.ptr {name = ""}
632+
omp.target map_entries(%map -> %ptr_arg : !llvm.ptr) {
633+
omp.terminator
634+
}
635+
llvm.return
636+
}
637+
}
638+
639+
// CHECK: @.offload_sizes = private unnamed_addr constant [1 x i64] [i64 8]
640+
// CHECK: @.offload_maptypes = private unnamed_addr constant [1 x i64] [i64 288]
641+
// CHECK-LABEL: define void @_QPomp_target_is_device_ptr

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -212,17 +212,6 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) {
212212

213213
// -----
214214

215-
llvm.func @target_is_device_ptr(%x : !llvm.ptr) {
216-
// expected-error@below {{not yet implemented: Unhandled clause is_device_ptr in omp.target operation}}
217-
// expected-error@below {{LLVM Translation failed for operation: omp.target}}
218-
omp.target is_device_ptr(%x : !llvm.ptr) {
219-
omp.terminator
220-
}
221-
llvm.return
222-
}
223-
224-
// -----
225-
226215
llvm.func @target_enter_data_depend(%x: !llvm.ptr) {
227216
// expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}}
228217
// expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}

0 commit comments

Comments
 (0)