Commit 2f3acbd

[MLIR][OpenMP] Introduce host_eval clause to omp.target (#178)
This patch defines a map-like clause named `host_eval`, used to capture host values for use inside of target regions in restricted cases:
- As `num_teams` or `thread_limit` of a nested `omp.teams` operation.
- As `num_threads` of a nested `omp.parallel` operation, or as bounds or steps of a nested `omp.loop_nest`, if it is a target SPMD kernel.

This replaces the following `omp.target` arguments: `trip_count`, `num_threads`, `num_teams_lower`, `num_teams_upper` and `teams_thread_limit`.
1 parent bcb94bb commit 2f3acbd

7 files changed: 296 additions and 60 deletions.

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 57 additions & 1 deletion
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
 introduction of private copies of the same underlying variable defined outside
 the MLIR operation the clause is attached to. Currently, clauses with this
 property can be classified into three main categories:
-- Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
+- Map-like clauses: `host_eval`, `map`, `use_device_addr` and
+  `use_device_ptr`.
 - Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
 - Privatization clauses: `private`.
 

@@ -522,3 +523,58 @@ omp.parallel ... {
   omp.terminator
 } {omp.composite}
 ```
+
+## Host-Evaluated Clauses in Target Regions
+
+The `omp.target` operation, which represents the OpenMP `target` construct, is
+marked with the `IsolatedFromAbove` trait. This means that, inside of its
+region, no MLIR values defined outside of the op itself can be used. This is
+consistent with the OpenMP specification of the `target` construct, which
+mandates that all host device values used inside of the `target` region must
+either be privatized (data-sharing) or mapped (data-mapping).
+
+Normally, clauses applied to a construct are evaluated before entering that
+construct. Further, in some cases, the OpenMP specification stipulates that
+clauses be evaluated _on the host device_ on entry to a parent `target`
+construct. In particular, the `num_teams` and `thread_limit` clauses of the
+`teams` construct must be evaluated on the host device if it's nested inside or
+combined with a `target` construct.
+
+Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
+the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
+`target teams distribute parallel {do,for}` in OpenMP), which requires
+specifying in advance what the total trip count of the loop is. Consequently, it
+is also beneficial to evaluate the trip count on the host device prior to the
+kernel launch.
+
+These host-evaluated values in MLIR would need to be placed outside of the
+`omp.target` region and also attached to the corresponding nested operations,
+which is not possible because of the `IsolatedFromAbove` trait. The solution
+implemented to address this problem has been to introduce the `host_eval`
+argument to the `omp.target` operation. It works similarly to a `map` clause,
+but its only intended use is to forward host-evaluated values to their
+corresponding operation inside of the region. Any uses outside of the previously
+described result in a verifier error.
+
+```mlir
+// Initialize %0, %1, %2, %3...
+omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
+  omp.teams num_teams(to %nt : i32) {
+    omp.parallel {
+      omp.distribute {
+        omp.wsloop {
+          omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+            // ...
+            omp.yield
+          }
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  omp.terminator
+}
+```
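
The documentation added in this file notes that the optimized SPMD launch needs the total trip count of the loop ahead of time. As a rough illustration of what evaluating that count on the host involves, here is a minimal C++ sketch. The `tripCount` helper is hypothetical and is not part of MLIR or of any OpenMP runtime; it assumes an exclusive upper bound and treats a zero step as an empty loop, which may not match how a given `omp.loop_nest` is configured.

```cpp
#include <cstdint>

// Hypothetical helper (not an MLIR or OpenMP runtime API): number of
// iterations of a loop that starts at `lb`, stops before reaching `ub`
// (exclusive bound, an assumption of this sketch) and advances by `step`.
std::uint64_t tripCount(std::int64_t lb, std::int64_t ub, std::int64_t step) {
  if (step == 0)
    return 0; // Degenerate loop; treated as empty in this sketch.

  std::uint64_t span = 0;
  std::uint64_t stride = 0;
  if (step > 0) {
    if (lb >= ub)
      return 0;
    // Unsigned subtraction avoids signed overflow for wide ranges.
    span = static_cast<std::uint64_t>(ub) - static_cast<std::uint64_t>(lb);
    stride = static_cast<std::uint64_t>(step);
  } else {
    if (lb <= ub)
      return 0;
    span = static_cast<std::uint64_t>(lb) - static_cast<std::uint64_t>(ub);
    // Negate in unsigned arithmetic so INT64_MIN is handled too.
    stride = std::uint64_t{0} - static_cast<std::uint64_t>(step);
  }
  // Ceiling division without risking overflow in `span + stride - 1`.
  return span / stride + (span % stride != 0 ? 1 : 0);
}
```

In the `host_eval` example, the values bound to `%lb`, `%ub` and `%step` are what such a computation would consume before the kernel is launched.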

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 38 additions & 0 deletions
@@ -419,6 +419,44 @@ class OpenMP_HintClauseSkip<
 
 def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold host-evaluated values.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_HostEvalClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
+  let arguments = (ins
+    Variadic<AnyType>:$host_eval_vars
+  );
+
+  let extraClassDeclaration = [{
+    unsigned numHostEvalBlockArgs() {
+      return getHostEvalVars().size();
+    }
+  }];
+
+  let description = [{
+    The optional `host_eval_vars` holds values defined outside of the region of
+    the `IsolatedFromAbove` operation for which a corresponding entry block
+    argument is defined. The only legal uses for these captured values are the
+    following:
+      - `num_teams` or `thread_limit` clause of an immediately nested
+        `omp.teams` operation.
+      - If the operation is the top-level `omp.target` of a target SPMD kernel:
+        - `num_threads` clause of the nested `omp.parallel` operation.
+        - Bounds and steps of the nested `omp.loop_nest` operation.
+  }];
+}
+
+def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [3.4] `if` clause
 //===----------------------------------------------------------------------===//
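
The legality rules listed in the clause description can be read as a small decision table. The sketch below models them in plain C++ purely for illustration. The `HostEvalUse` enum and the `isLegalHostEvalUse` function are hypothetical names and do not reflect how the actual verifier is implemented; the sketch also assumes the `omp.teams` use is already known to be immediately nested in the `omp.target`.

```cpp
// Hypothetical classification of where a `host_eval` block argument is used.
enum class HostEvalUse {
  TeamsNumTeams,       // `num_teams` of an immediately nested `omp.teams`
  TeamsThreadLimit,    // `thread_limit` of an immediately nested `omp.teams`
  ParallelNumThreads,  // `num_threads` of the nested `omp.parallel`
  LoopNestBoundOrStep, // bound or step of the nested `omp.loop_nest`
  Other                // any other use
};

// Returns true if the use is one of the legal cases in the description.
// `isTargetSPMD` states whether the enclosing `omp.target` is the top-level
// operation of a target SPMD kernel.
bool isLegalHostEvalUse(HostEvalUse use, bool isTargetSPMD) {
  switch (use) {
  case HostEvalUse::TeamsNumTeams:
  case HostEvalUse::TeamsThreadLimit:
    return true; // Legal regardless of the kernel kind.
  case HostEvalUse::ParallelNumThreads:
  case HostEvalUse::LoopNestBoundOrStep:
    return isTargetSPMD; // Legal only for target SPMD kernels.
  case HostEvalUse::Other:
    return false;
  }
  return false;
}
```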

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 10 additions & 21 deletions
@@ -1116,20 +1116,16 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
 
-// TODO: Remove num_threads, teams_thread_limit and trip_count and implement the
-// passthrough approach described here:
-// https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106.
 def TargetOp : OpenMP_Op<"target", traits = [
     AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
     OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
-    OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
-    OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
-    OpenMP_NowaitClause, OpenMP_NumTeamsClauseSkip<description = true>,
-    OpenMP_NumThreadsClauseSkip<description = true>, OpenMP_PrivateClause,
-    OpenMP_ThreadLimitClause
+    OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause,
+    OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
+    OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
+    OpenMP_PrivateClause, OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "target construct";
   let description = [{
@@ -1156,10 +1152,6 @@ def TargetOp : OpenMP_Op<"target", traits = [
     an `omp.parallel`.
   }] # clausesDescription;
 
-  let arguments = !con(clausesArgs,
-                       (ins Optional<AnyInteger>:$trip_count,
-                            Optional<AnyInteger>:$teams_thread_limit));
-
   let builders = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
@@ -1184,15 +1176,12 @@ def TargetOp : OpenMP_Op<"target", traits = [
     bool isTargetSPMDLoop();
   }] # clausesExtraClassDeclaration;
 
-  let assemblyFormat = clausesReqAssemblyFormat #
-    " oilist(" # clausesOptAssemblyFormat # [{
-      | `trip_count` `(` $trip_count `:` type($trip_count) `)`
-      | `teams_thread_limit` `(` $teams_thread_limit `:` type($teams_thread_limit) `)`
-    }] # ")" # [{
-      custom<InReductionMapPrivateRegion>(
-        $region, $in_reduction_vars, type($in_reduction_vars),
-        $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms) attr-dict
+  let assemblyFormat = clausesAssemblyFormat # [{
+    custom<HostEvalInReductionMapPrivateRegion>(
+        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
+        type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
+        $map_vars, type($map_vars), $private_vars, type($private_vars),
+        $private_syms) attr-dict
   }];
 
   let hasVerifier = 1;

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 22 additions & 5 deletions
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let methods = [
     // Default-implemented methods to be overriden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
+                    "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
     InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
                     "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
       return 0;
@@ -55,9 +59,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
     }]>,
 
     // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
+                    "unsigned", "getHostEvalBlockArgsStart", (ins), [{
+      return 0;
+    }]>,
     InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
                     "unsigned", "getInReductionBlockArgsStart", (ins), [{
-      return 0;
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
     }]>,
     InterfaceMethod<"Get start index of block arguments defined by `map`.",
                     "unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +100,13 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
       return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
     }]>,
 
+    InterfaceMethod<"Get block arguments defined by `host_eval`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getHostEvalBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
+    }]>,
     InterfaceMethod<"Get block arguments defined by `in_reduction`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getInReductionBlockArgs", (ins), [{
@@ -147,10 +163,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-    unsigned expectedArgs = iface.numInReductionBlockArgs() +
-      iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
-      iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
-      iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
+    unsigned expectedArgs = iface.numHostEvalBlockArgs() +
+      iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
+      iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
+      iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
+      iface.numUseDevicePtrBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
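
The interface changes in this file chain each clause's start index off the previous clause, with `host_eval` now first, and the verifier sums all the counts. A standalone C++ model of that layout may make the arithmetic easier to follow. It is only a sketch: `Clause` and `BlockArgLayout` are hypothetical names, and the ordering of the clauses after `in_reduction` is assumed from the order in the verifier's sum, since the diff does not show every start-index method.

```cpp
#include <array>
#include <cstddef>
#include <numeric>

// Assumed clause order for entry block arguments; only the first two
// positions (`host_eval`, then `in_reduction`) are shown by the diff.
enum Clause : std::size_t {
  HostEval,
  InReduction,
  Map,
  Private,
  Reduction,
  TaskReduction,
  UseDeviceAddr,
  UseDevicePtr,
  NumClauses
};

// Hypothetical model of the block-argument layout, independent of MLIR.
struct BlockArgLayout {
  std::array<unsigned, NumClauses> counts{};

  // Start index of a clause: sum of the counts of all preceding clauses.
  unsigned start(Clause c) const {
    return std::accumulate(counts.begin(),
                           counts.begin() + static_cast<std::ptrdiff_t>(c),
                           0u);
  }

  // Minimum number of entry block arguments the region must define.
  unsigned expectedArgs() const {
    return std::accumulate(counts.begin(), counts.end(), 0u);
  }

  // Mirrors the "expected at least N entry block argument(s)" check.
  bool verify(unsigned numRegionArgs) const {
    return numRegionArgs >= expectedArgs();
  }
};
```

With four `host_eval` values and two `map` values, for instance, the model places the `map` arguments at index 4 and requires at least six entry block arguments.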
