@@ -2205,8 +2205,9 @@ LogicalResult TargetOp::verifyRegions() {
22052205 return emitError (" target containing multiple 'omp.teams' nested ops" );
22062206
22072207 // Check that host_eval values are only used in legal ways.
2208+ bool hostEvalTripCount;
22082209 Operation *capturedOp = getInnermostCapturedOmpOp ();
2209- TargetRegionFlags execFlags = getKernelExecFlags (capturedOp);
2210+ TargetExecMode execMode = getKernelExecFlags (capturedOp, &hostEvalTripCount );
22102211 for (Value hostEvalArg :
22112212 cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
22122213 for (Operation *user : hostEvalArg.getUsers ()) {
@@ -2221,7 +2222,7 @@ LogicalResult TargetOp::verifyRegions() {
22212222 " and 'thread_limit' in 'omp.teams'" ;
22222223 }
22232224 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2224- if (bitEnumContainsAny (execFlags, TargetRegionFlags ::spmd) &&
2225+ if (execMode == TargetExecMode ::spmd &&
22252226 parallelOp->isAncestor (capturedOp) &&
22262227 hostEvalArg == parallelOp.getNumThreads ())
22272228 continue ;
@@ -2231,8 +2232,7 @@ LogicalResult TargetOp::verifyRegions() {
22312232 " 'omp.parallel' when representing target SPMD" ;
22322233 }
22332234 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2234- if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
2235- loopNestOp.getOperation () == capturedOp &&
2235+ if (hostEvalTripCount && loopNestOp.getOperation () == capturedOp &&
22362236 (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
22372237 llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
22382238 llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
@@ -2362,7 +2362,9 @@ static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
23622362 ompFlags.getAssumeThreadsOversubscription ();
23632363}
23642364
2365- TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
2365+ TargetExecMode TargetOp::getKernelExecFlags (Operation *capturedOp,
2366+ bool *hostEvalTripCount) {
2367+ // TODO: Support detection of bare kernel mode.
23662368 // A non-null captured op is only valid if it resides inside of a TargetOp
23672369 // and is the result of calling getInnermostCapturedOmpOp() on it.
23682370 TargetOp targetOp =
@@ -2371,9 +2373,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
23712373 (targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
23722374 " unexpected captured op" );
23732375
2376+ if (hostEvalTripCount)
2377+ *hostEvalTripCount = false ;
2378+
23742379 // If it's not capturing a loop, it's a default target region.
23752380 if (!isa_and_present<LoopNestOp>(capturedOp))
2376- return TargetRegionFlags::none ;
2381+ return TargetExecMode::generic ;
23772382
23782383 // Get the innermost non-simd loop wrapper.
23792384 SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,59 +2391,63 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
23862391
23872392 auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
23882393 if (numWrappers != 1 && numWrappers != 2 )
2389- return TargetRegionFlags::none ;
2394+ return TargetExecMode::generic ;
23902395
23912396 // Detect target-teams-distribute-parallel-wsloop[-simd].
23922397 if (numWrappers == 2 ) {
23932398 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
23942399 if (!wsloopOp)
2395- return TargetRegionFlags::none ;
2400+ return TargetExecMode::generic ;
23962401
23972402 innermostWrapper = std::next (innermostWrapper);
23982403 if (!isa<DistributeOp>(innermostWrapper))
2399- return TargetRegionFlags::none ;
2404+ return TargetExecMode::generic ;
24002405
24012406 Operation *parallelOp = (*innermostWrapper)->getParentOp ();
24022407 if (!isa_and_present<ParallelOp>(parallelOp))
2403- return TargetRegionFlags::none ;
2408+ return TargetExecMode::generic ;
24042409
24052410 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp ());
24062411 if (!teamsOp)
2407- return TargetRegionFlags::none ;
2412+ return TargetExecMode::generic ;
24082413
24092414 if (teamsOp->getParentOp () == targetOp.getOperation ()) {
2410- TargetRegionFlags result =
2411- TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2415+ TargetExecMode result = TargetExecMode::spmd;
24122416 if (canPromoteToNoLoop (capturedOp, teamsOp, wsloopOp))
2413- result = result | TargetRegionFlags::no_loop;
2417+ result = TargetExecMode::no_loop;
2418+ if (hostEvalTripCount)
2419+ *hostEvalTripCount = true ;
24142420 return result;
24152421 }
24162422 }
24172423 // Detect target-teams-distribute[-simd] and target-teams-loop.
24182424 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
24192425 Operation *teamsOp = (*innermostWrapper)->getParentOp ();
24202426 if (!isa_and_present<TeamsOp>(teamsOp))
2421- return TargetRegionFlags::none ;
2427+ return TargetExecMode::generic ;
24222428
24232429 if (teamsOp->getParentOp () != targetOp.getOperation ())
2424- return TargetRegionFlags::none;
2430+ return TargetExecMode::generic;
2431+
2432+ if (hostEvalTripCount)
2433+ *hostEvalTripCount = true ;
24252434
24262435 if (isa<LoopOp>(innermostWrapper))
2427- return TargetRegionFlags ::spmd | TargetRegionFlags::trip_count ;
2436+ return TargetExecMode ::spmd;
24282437
2429- return TargetRegionFlags::trip_count ;
2438+ return TargetExecMode::generic ;
24302439 }
24312440 // Detect target-parallel-wsloop[-simd].
24322441 else if (isa<WsloopOp>(innermostWrapper)) {
24332442 Operation *parallelOp = (*innermostWrapper)->getParentOp ();
24342443 if (!isa_and_present<ParallelOp>(parallelOp))
2435- return TargetRegionFlags::none ;
2444+ return TargetExecMode::generic ;
24362445
24372446 if (parallelOp->getParentOp () == targetOp.getOperation ())
2438- return TargetRegionFlags ::spmd;
2447+ return TargetExecMode ::spmd;
24392448 }
24402449
2441- return TargetRegionFlags::none ;
2450+ return TargetExecMode::generic ;
24422451}
24432452
24442453// ===----------------------------------------------------------------------===//
0 commit comments