Skip to content

Commit 820c957

Browse files
committed
Address review comments
1 parent 50619d2 commit 820c957

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,20 +1343,15 @@ def TargetOp : OpenMP_Op<"target", traits = [
13431343
///
13441344
/// If there are omp.loop_nest operations in the sequence of nested
13451345
/// operations, the top level one will be the one captured.
1346-
///
1347-
/// This is a relatively expensive operation, so if it is needed at the same
1348-
/// time as the result of `getKernelExecFlags` it is preferable to cache the
1349-
/// result of this function and pass it along.
13501346
Operation *getInnermostCapturedOmpOp();
13511347

13521348
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
13531349
/// contents of the target region.
13541350
///
13551351
/// \param capturedOp result of a still valid (no modifications made to any
1356-
/// nested operations) previous call to `getInnermostCapturedOmpOp()`. If
1357-
/// not specified, this will call that function itself instead.
1358-
llvm::omp::OMPTgtExecModeFlags
1359-
getKernelExecFlags(std::optional<Operation *> capturedOp = std::nullopt);
1352+
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
1353+
static llvm::omp::OMPTgtExecModeFlags
1354+
getKernelExecFlags(Operation *capturedOp);
13601355
}] # clausesExtraClassDeclaration;
13611356

13621357
let assemblyFormat = clausesAssemblyFormat # [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,7 +1905,8 @@ LogicalResult TargetOp::verifyRegions() {
19051905
return emitError("target containing multiple 'omp.teams' nested ops");
19061906

19071907
// Check that host_eval values are only used in legal ways.
1908-
llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
1908+
llvm::omp::OMPTgtExecModeFlags execFlags =
1909+
getKernelExecFlags(getInnermostCapturedOmpOp());
19091910
for (Value hostEvalArg :
19101911
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
19111912
for (Operation *user : hostEvalArg.getUsers()) {
@@ -2026,23 +2027,26 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20262027
}
20272028

20282029
llvm::omp::OMPTgtExecModeFlags
2029-
TargetOp::getKernelExecFlags(std::optional<Operation *> capturedOp) {
2030+
TargetOp::getKernelExecFlags(Operation *capturedOp) {
20302031
using namespace llvm::omp;
20312032

2032-
// Use a cached operation, if passed in. Otherwise, find the innermost
2033-
// captured operation.
2034-
if (!capturedOp)
2035-
capturedOp = getInnermostCapturedOmpOp();
2036-
assert(*capturedOp == getInnermostCapturedOmpOp() &&
2037-
"unexpected captured op");
2033+
#ifndef NDEBUG
2034+
if (capturedOp) {
2035+
// A non-null captured op is only valid if it resides inside of a TargetOp
2036+
// and is the result of calling getInnermostCapturedOmpOp() on it.
2037+
TargetOp targetOp = capturedOp->getParentOfType<TargetOp>();
2038+
assert(targetOp && targetOp.getInnermostCapturedOmpOp() &&
2039+
"unexpected captured op");
2040+
}
2041+
#endif
20382042

20392043
// Make sure this region is capturing a loop. Otherwise, it's a generic
20402044
// kernel.
2041-
if (!isa_and_present<LoopNestOp>(*capturedOp))
2045+
if (!isa_and_present<LoopNestOp>(capturedOp))
20422046
return OMP_TGT_EXEC_MODE_GENERIC;
20432047

20442048
SmallVector<LoopWrapperInterface> wrappers;
2045-
cast<LoopNestOp>(*capturedOp).gatherWrappers(wrappers);
2049+
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
20462050
assert(!wrappers.empty());
20472051

20482052
// Ignore optional SIMD leaf construct.

0 commit comments

Comments
 (0)