@@ -1970,7 +1970,7 @@ LogicalResult TargetOp::verifyRegions() {
19701970}
19711971
19721972static Operation *
1973- findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1973+ findCapturedOmpOp (Operation *rootOp,
19741974 llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
19751975 assert (rootOp && " expected valid operation" );
19761976
@@ -1998,19 +1998,17 @@ findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
19981998 // (i.e. its block's successors can reach it) or if it's not guaranteed to
19991999 // be executed before all exits of the region (i.e. it doesn't dominate all
20002000 // blocks with no successors reachable from the entry block).
2001- if (checkSingleMandatoryExec) {
2002- Region *parentRegion = op->getParentRegion ();
2003- Block *parentBlock = op->getBlock ();
2004-
2005- for (Block *successor : parentBlock->getSuccessors ())
2006- if (successor->isReachable (parentBlock))
2007- return WalkResult::interrupt ();
2008-
2009- for (Block &block : *parentRegion)
2010- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2011- !domInfo.dominates (parentBlock, &block))
2012- return WalkResult::interrupt ();
2013- }
2001+ Region *parentRegion = op->getParentRegion ();
2002+ Block *parentBlock = op->getBlock ();
2003+
2004+ for (Block *successor : parentBlock->getSuccessors ())
2005+ if (successor->isReachable (parentBlock))
2006+ return WalkResult::interrupt ();
2007+
2008+ for (Block &block : *parentRegion)
2009+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2010+ !domInfo.dominates (parentBlock, &block))
2011+ return WalkResult::interrupt ();
20142012
20152013 // Don't capture this op if it has a not-allowed sibling, and stop recursing
20162014 // into nested operations.
@@ -2033,27 +2031,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20332031
20342032 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
20352033 // effects, but don't include a memory write effect.
2036- return findCapturedOmpOp (
2037- *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2038- if (!sibling)
2039- return false ;
2040-
2041- if (ompDialect == sibling->getDialect ())
2042- return sibling->hasTrait <OpTrait::IsTerminator>();
2043-
2044- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2045- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2046- effects;
2047- memOp.getEffects (effects);
2048- return !llvm::any_of (
2049- effects, [&](MemoryEffects::EffectInstance &effect) {
2050- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2051- isa<SideEffects::AutomaticAllocationScopeResource>(
2052- effect.getResource ());
2053- });
2054- }
2055- return true ;
2034+ return findCapturedOmpOp (*this , [&](Operation *sibling) {
2035+ if (!sibling)
2036+ return false ;
2037+
2038+ if (ompDialect == sibling->getDialect ())
2039+ return sibling->hasTrait <OpTrait::IsTerminator>();
2040+
2041+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2042+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2043+ effects;
2044+ memOp.getEffects (effects);
2045+ return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
2046+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2047+ isa<SideEffects::AutomaticAllocationScopeResource>(
2048+ effect.getResource ());
20562049 });
2050+ }
2051+ return true ;
2052+ });
20572053}
20582054
20592055TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
@@ -2114,33 +2110,23 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21142110 if (isa<LoopOp>(innermostWrapper))
21152111 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
21162112
2117- // Find single immediately nested captured omp.parallel and add spmd flag
2118- // (generic-spmd case).
2113+ // Add spmd flag if there's a nested omp.parallel (generic-spmd case).
21192114 //
21202115 // TODO: This shouldn't have to be done here, as it is too easy to break.
21212116 // The openmp-opt pass should be updated to be able to promote kernels like
21222117 // this from "Generic" to "Generic-SPMD". However, the use of the
21232118 // `kmpc_distribute_static_loop` family of functions produced by the
21242119 // OMPIRBuilder for these kernels prevents that from working.
2125- Dialect *ompDialect = targetOp->getDialect ();
2126- Operation *nestedCapture = findCapturedOmpOp (
2127- capturedOp, /* checkSingleMandatoryExec=*/ false ,
2128- [&](Operation *sibling) {
2129- return sibling && (ompDialect != sibling->getDialect () ||
2130- sibling->hasTrait <OpTrait::IsTerminator>());
2131- });
2120+ bool hasParallel = capturedOp
2121+ ->walk <WalkOrder::PreOrder>([](ParallelOp) {
2122+ return WalkResult::interrupt ();
2123+ })
2124+ .wasInterrupted ();
21322125
21332126 TargetRegionFlags result =
21342127 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
21352128
2136- if (!nestedCapture)
2137- return result;
2138-
2139- while (nestedCapture->getParentOp () != capturedOp)
2140- nestedCapture = nestedCapture->getParentOp ();
2141-
2142- return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2143- : result;
2129+ return hasParallel ? result | TargetRegionFlags::spmd : result;
21442130 }
21452131 // Detect target-parallel-wsloop[-simd].
21462132 else if (isa<WsloopOp>(innermostWrapper)) {
0 commit comments