@@ -1970,7 +1970,7 @@ LogicalResult TargetOp::verifyRegions() {
1970
1970
}
1971
1971
1972
1972
static Operation *
1973
- findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1973
+ findCapturedOmpOp (Operation *rootOp,
1974
1974
llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
1975
1975
assert (rootOp && " expected valid operation" );
1976
1976
@@ -1998,19 +1998,17 @@ findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
1998
1998
// (i.e. its block's successors can reach it) or if it's not guaranteed to
1999
1999
// be executed before all exits of the region (i.e. it doesn't dominate all
2000
2000
// blocks with no successors reachable from the entry block).
2001
- if (checkSingleMandatoryExec) {
2002
- Region *parentRegion = op->getParentRegion ();
2003
- Block *parentBlock = op->getBlock ();
2004
-
2005
- for (Block *successor : parentBlock->getSuccessors ())
2006
- if (successor->isReachable (parentBlock))
2007
- return WalkResult::interrupt ();
2008
-
2009
- for (Block &block : *parentRegion)
2010
- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2011
- !domInfo.dominates (parentBlock, &block))
2012
- return WalkResult::interrupt ();
2013
- }
2001
+ Region *parentRegion = op->getParentRegion ();
2002
+ Block *parentBlock = op->getBlock ();
2003
+
2004
+ for (Block *successor : parentBlock->getSuccessors ())
2005
+ if (successor->isReachable (parentBlock))
2006
+ return WalkResult::interrupt ();
2007
+
2008
+ for (Block &block : *parentRegion)
2009
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2010
+ !domInfo.dominates (parentBlock, &block))
2011
+ return WalkResult::interrupt ();
2014
2012
2015
2013
// Don't capture this op if it has a not-allowed sibling, and stop recursing
2016
2014
// into nested operations.
@@ -2033,27 +2031,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2033
2031
2034
2032
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
2035
2033
// effects, but don't include a memory write effect.
2036
- return findCapturedOmpOp (
2037
- *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2038
- if (!sibling)
2039
- return false ;
2040
-
2041
- if (ompDialect == sibling->getDialect ())
2042
- return sibling->hasTrait <OpTrait::IsTerminator>();
2043
-
2044
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2045
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2046
- effects;
2047
- memOp.getEffects (effects);
2048
- return !llvm::any_of (
2049
- effects, [&](MemoryEffects::EffectInstance &effect) {
2050
- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2051
- isa<SideEffects::AutomaticAllocationScopeResource>(
2052
- effect.getResource ());
2053
- });
2054
- }
2055
- return true ;
2034
+ return findCapturedOmpOp (*this , [&](Operation *sibling) {
2035
+ if (!sibling)
2036
+ return false ;
2037
+
2038
+ if (ompDialect == sibling->getDialect ())
2039
+ return sibling->hasTrait <OpTrait::IsTerminator>();
2040
+
2041
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2042
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2043
+ effects;
2044
+ memOp.getEffects (effects);
2045
+ return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
2046
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2047
+ isa<SideEffects::AutomaticAllocationScopeResource>(
2048
+ effect.getResource ());
2056
2049
});
2050
+ }
2051
+ return true ;
2052
+ });
2057
2053
}
2058
2054
2059
2055
TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
@@ -2114,33 +2110,23 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2114
2110
if (isa<LoopOp>(innermostWrapper))
2115
2111
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2116
2112
2117
- // Find single immediately nested captured omp.parallel and add spmd flag
2118
- // (generic-spmd case).
2113
+ // Add spmd flag if there's a nested omp.parallel (generic-spmd case).
2119
2114
//
2120
2115
// TODO: This shouldn't have to be done here, as it is too easy to break.
2121
2116
// The openmp-opt pass should be updated to be able to promote kernels like
2122
2117
// this from "Generic" to "Generic-SPMD". However, the use of the
2123
2118
// `kmpc_distribute_static_loop` family of functions produced by the
2124
2119
// OMPIRBuilder for these kernels prevents that from working.
2125
- Dialect *ompDialect = targetOp->getDialect ();
2126
- Operation *nestedCapture = findCapturedOmpOp (
2127
- capturedOp, /* checkSingleMandatoryExec=*/ false ,
2128
- [&](Operation *sibling) {
2129
- return sibling && (ompDialect != sibling->getDialect () ||
2130
- sibling->hasTrait <OpTrait::IsTerminator>());
2131
- });
2120
+ bool hasParallel = capturedOp
2121
+ ->walk <WalkOrder::PreOrder>([](ParallelOp) {
2122
+ return WalkResult::interrupt ();
2123
+ })
2124
+ .wasInterrupted ();
2132
2125
2133
2126
TargetRegionFlags result =
2134
2127
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2135
2128
2136
- if (!nestedCapture)
2137
- return result;
2138
-
2139
- while (nestedCapture->getParentOp () != capturedOp)
2140
- nestedCapture = nestedCapture->getParentOp ();
2141
-
2142
- return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2143
- : result;
2129
+ return hasParallel ? result | TargetRegionFlags::spmd : result;
2144
2130
}
2145
2131
// Detect target-parallel-wsloop[-simd].
2146
2132
else if (isa<WsloopOp>(innermostWrapper)) {
0 commit comments