@@ -1954,7 +1954,7 @@ LogicalResult TargetOp::verifyRegions() {
19541954}
19551955
19561956static Operation *
1957- findCapturedOmpOp (Operation *rootOp,
1957+ findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
19581958 llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
19591959 assert (rootOp && " expected valid operation" );
19601960
@@ -1982,17 +1982,19 @@ findCapturedOmpOp(Operation *rootOp,
19821982 // (i.e. its block's successors can reach it) or if it's not guaranteed to
19831983 // be executed before all exits of the region (i.e. it doesn't dominate all
19841984 // blocks with no successors reachable from the entry block).
1985- Region *parentRegion = op->getParentRegion ();
1986- Block *parentBlock = op->getBlock ();
1987-
1988- for (Block *successor : parentBlock->getSuccessors ())
1989- if (successor->isReachable (parentBlock))
1990- return WalkResult::interrupt ();
1991-
1992- for (Block &block : *parentRegion)
1993- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1994- !domInfo.dominates (parentBlock, &block))
1995- return WalkResult::interrupt ();
1985+ if (checkSingleMandatoryExec) {
1986+ Region *parentRegion = op->getParentRegion ();
1987+ Block *parentBlock = op->getBlock ();
1988+
1989+ for (Block *successor : parentBlock->getSuccessors ())
1990+ if (successor->isReachable (parentBlock))
1991+ return WalkResult::interrupt ();
1992+
1993+ for (Block &block : *parentRegion)
1994+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1995+ !domInfo.dominates (parentBlock, &block))
1996+ return WalkResult::interrupt ();
1997+ }
19961998
19971999 // Don't capture this op if it has a not-allowed sibling, and stop recursing
19982000 // into nested operations.
@@ -2015,25 +2017,27 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20152017
20162018 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
20172019 // effects, but don't include a memory write effect.
2018- return findCapturedOmpOp (*this , [&](Operation *sibling) {
2019- if (!sibling)
2020- return false ;
2021-
2022- if (ompDialect == sibling->getDialect ())
2023- return sibling->hasTrait <OpTrait::IsTerminator>();
2024-
2025- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2026- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2027- effects;
2028- memOp.getEffects (effects);
2029- return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
2030- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2031- isa<SideEffects::AutomaticAllocationScopeResource>(
2032- effect.getResource ());
2020+ return findCapturedOmpOp (
2021+ *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2022+ if (!sibling)
2023+ return false ;
2024+
2025+ if (ompDialect == sibling->getDialect ())
2026+ return sibling->hasTrait <OpTrait::IsTerminator>();
2027+
2028+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2029+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2030+ effects;
2031+ memOp.getEffects (effects);
2032+ return !llvm::any_of (
2033+ effects, [&](MemoryEffects::EffectInstance &effect) {
2034+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2035+ isa<SideEffects::AutomaticAllocationScopeResource>(
2036+ effect.getResource ());
2037+ });
2038+ }
2039+ return true ;
20332040 });
2034- }
2035- return true ;
2036- });
20372041}
20382042
20392043TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
@@ -2108,8 +2112,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21082112 // `kmpc_distribute_static_loop` family of functions produced by the
21092113 // OMPIRBuilder for these kernels prevents that from working.
21102114 Dialect *ompDialect = targetOp->getDialect ();
2111- Operation *nestedCapture =
2112- findCapturedOmpOp (capturedOp, [&](Operation *sibling) {
2115+ Operation *nestedCapture = findCapturedOmpOp (
2116+ capturedOp, /* checkSingleMandatoryExec=*/ false ,
2117+ [&](Operation *sibling) {
21132118 return sibling && (ompDialect != sibling->getDialect () ||
21142119 sibling->hasTrait <OpTrait::IsTerminator>());
21152120 });
0 commit comments