@@ -2099,10 +2099,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
20992099 if (teamsOp->getParentOp () != targetOp.getOperation ())
21002100 return TargetRegionFlags::generic;
21012101
2102- TargetRegionFlags result =
2103- TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2104-
2105- // Find single nested parallel-do and add spmd flag (generic-spmd case).
2102+ // Find single immediately nested captured omp.parallel and add spmd flag
2103+ // (generic-spmd case).
2104+ //
21062105 // TODO: This shouldn't have to be done here, as it is too easy to break.
21072106 // The openmp-opt pass should be updated to be able to promote kernels like
21082107 // this from "Generic" to "Generic-SPMD". However, the use of the
@@ -2115,24 +2114,17 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21152114 sibling->hasTrait <OpTrait::IsTerminator>());
21162115 });
21172116
2118- if (!isa_and_present<LoopNestOp>(nestedCapture))
2119- return result;
2120-
2121- int numNestedWrappers;
2122- LoopWrapperInterface *nestedWrapper =
2123- getInnermostWrapper (cast<LoopNestOp>(nestedCapture), numNestedWrappers);
2124-
2125- if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
2126- return result;
2117+ TargetRegionFlags result =
2118+ TargetRegionFlags::generic | TargetRegionFlags::trip_count;
21272119
2128- Operation *parallelOp = (*nestedWrapper)->getParentOp ();
2129- if (!isa_and_present<ParallelOp>(parallelOp))
2120+ if (!nestedCapture)
21302121 return result;
21312122
2132- if (parallelOp ->getParentOp () != capturedOp)
2133- return result ;
2123+ while (nestedCapture ->getParentOp () != capturedOp)
2124+ nestedCapture = nestedCapture-> getParentOp () ;
21342125
2135- return result | TargetRegionFlags::spmd;
2126+ return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2127+ : result;
21362128 }
21372129 // Detect target-parallel-wsloop[-simd].
21382130 else if (isa<WsloopOp>(innermostWrapper)) {
0 commit comments