@@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() {
19741974 return emitError (" target containing multiple 'omp.teams' nested ops" );
19751975
19761976 // Check that host_eval values are only used in legal ways.
1977+ bool hostEvalTripCount;
19771978 Operation *capturedOp = getInnermostCapturedOmpOp ();
1978- TargetRegionFlags execFlags = getKernelExecFlags (capturedOp);
1979+ TargetExecMode execMode = getKernelExecFlags (capturedOp, &hostEvalTripCount );
19791980 for (Value hostEvalArg :
19801981 cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
19811982 for (Operation *user : hostEvalArg.getUsers ()) {
@@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() {
19901991 " and 'thread_limit' in 'omp.teams'" ;
19911992 }
19921993 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993- if (bitEnumContainsAny (execFlags, TargetRegionFlags ::spmd) &&
1994+ if (execMode == TargetExecMode ::spmd &&
19941995 parallelOp->isAncestor (capturedOp) &&
19951996 hostEvalArg == parallelOp.getNumThreads ())
19961997 continue ;
@@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() {
20002001 " 'omp.parallel' when representing target SPMD" ;
20012002 }
20022003 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003- if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
2004- loopNestOp.getOperation () == capturedOp &&
2004+ if (hostEvalTripCount && loopNestOp.getOperation () == capturedOp &&
20052005 (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
20062006 llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
20072007 llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
@@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
21062106 });
21072107}
21082108
2109- TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
2109+ TargetExecMode TargetOp::getKernelExecFlags (Operation *capturedOp,
2110+ bool *hostEvalTripCount) {
2111+ // TODO: Support detection of bare kernel mode.
21102112 // A non-null captured op is only valid if it resides inside of a TargetOp
21112113 // and is the result of calling getInnermostCapturedOmpOp() on it.
21122114 TargetOp targetOp =
@@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21152117 (targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
21162118 " unexpected captured op" );
21172119
2120+ if (hostEvalTripCount)
2121+ *hostEvalTripCount = false ;
2122+
21182123 // If it's not capturing a loop, it's a default target region.
21192124 if (!isa_and_present<LoopNestOp>(capturedOp))
2120- return TargetRegionFlags::none ;
2125+ return TargetExecMode::generic ;
21212126
21222127 // Get the innermost non-simd loop wrapper.
21232128 SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21302135
21312136 auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
21322137 if (numWrappers != 1 && numWrappers != 2 )
2133- return TargetRegionFlags::none ;
2138+ return TargetExecMode::generic ;
21342139
21352140 // Detect target-teams-distribute-parallel-wsloop[-simd].
21362141 if (numWrappers == 2 ) {
21372142 if (!isa<WsloopOp>(innermostWrapper))
2138- return TargetRegionFlags::none ;
2143+ return TargetExecMode::generic ;
21392144
21402145 innermostWrapper = std::next (innermostWrapper);
21412146 if (!isa<DistributeOp>(innermostWrapper))
2142- return TargetRegionFlags::none ;
2147+ return TargetExecMode::generic ;
21432148
21442149 Operation *parallelOp = (*innermostWrapper)->getParentOp ();
21452150 if (!isa_and_present<ParallelOp>(parallelOp))
2146- return TargetRegionFlags::none ;
2151+ return TargetExecMode::generic ;
21472152
21482153 Operation *teamsOp = parallelOp->getParentOp ();
21492154 if (!isa_and_present<TeamsOp>(teamsOp))
2150- return TargetRegionFlags::none ;
2155+ return TargetExecMode::generic ;
21512156
2152- if (teamsOp->getParentOp () == targetOp.getOperation ())
2153- return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2157+ if (teamsOp->getParentOp () == targetOp.getOperation ()) {
2158+ if (hostEvalTripCount)
2159+ *hostEvalTripCount = true ;
2160+ return TargetExecMode::spmd;
2161+ }
21542162 }
21552163 // Detect target-teams-distribute[-simd] and target-teams-loop.
21562164 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
21572165 Operation *teamsOp = (*innermostWrapper)->getParentOp ();
21582166 if (!isa_and_present<TeamsOp>(teamsOp))
2159- return TargetRegionFlags::none ;
2167+ return TargetExecMode::generic ;
21602168
21612169 if (teamsOp->getParentOp () != targetOp.getOperation ())
2162- return TargetRegionFlags::none;
2170+ return TargetExecMode::generic;
2171+
2172+ if (hostEvalTripCount)
2173+ *hostEvalTripCount = true ;
21632174
21642175 if (isa<LoopOp>(innermostWrapper))
2165- return TargetRegionFlags ::spmd | TargetRegionFlags::trip_count ;
2176+ return TargetExecMode ::spmd;
21662177
2167- return TargetRegionFlags::trip_count ;
2178+ return TargetExecMode::generic ;
21682179 }
21692180 // Detect target-parallel-wsloop[-simd].
21702181 else if (isa<WsloopOp>(innermostWrapper)) {
21712182 Operation *parallelOp = (*innermostWrapper)->getParentOp ();
21722183 if (!isa_and_present<ParallelOp>(parallelOp))
2173- return TargetRegionFlags::none ;
2184+ return TargetExecMode::generic ;
21742185
21752186 if (parallelOp->getParentOp () == targetOp.getOperation ())
2176- return TargetRegionFlags ::spmd;
2187+ return TargetExecMode ::spmd;
21772188 }
21782189
2179- return TargetRegionFlags::none ;
2190+ return TargetExecMode::generic ;
21802191}
21812192
21822193// ===----------------------------------------------------------------------===//
0 commit comments