@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
19081908 return emitError (" target containing multiple 'omp.teams' nested ops" );
19091909
19101910 // Check that host_eval values are only used in legal ways.
1911- llvm::omp::OMPTgtExecModeFlags execFlags =
1912- getKernelExecFlags (getInnermostCapturedOmpOp () );
1911+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1912+ TargetRegionFlags execFlags = getKernelExecFlags (capturedOp );
19131913 for (Value hostEvalArg :
19141914 cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
19151915 for (Operation *user : hostEvalArg.getUsers ()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
19241924 " and 'thread_limit' in 'omp.teams'" ;
19251925 }
19261926 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1927- if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1927+ if (bitEnumContainsAny (execFlags, TargetRegionFlags::spmd) &&
1928+ parallelOp->isAncestor (capturedOp) &&
19281929 hostEvalArg == parallelOp.getNumThreads ())
19291930 continue ;
19301931
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
19331934 " 'omp.parallel' when representing target SPMD" ;
19341935 }
19351936 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1936- if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1937+ if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
1938+ loopNestOp.getOperation () == capturedOp &&
19371939 (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
19381940 llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
19391941 llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
19401942 continue ;
19411943
19421944 return emitOpError () << " host_eval argument only legal as loop bounds "
1943- " and steps in 'omp.loop_nest' when "
1944- " representing target SPMD or Generic-SPMD " ;
1945+ " and steps in 'omp.loop_nest' when trip count "
1946+ " must be evaluated in the host " ;
19451947 }
19461948
19471949 return emitOpError () << " host_eval argument illegal use in '"
@@ -1951,42 +1953,21 @@ LogicalResult TargetOp::verifyRegions() {
19511953 return success ();
19521954}
19531955
1954- // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1955- // / effects, but don't include a memory write effect.
1956- static bool siblingAllowedInCapture (Operation *op) {
1957- if (!op)
1958- return false ;
1956+ static Operation *
1957+ findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1958+ llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
1959+ assert (rootOp && " expected valid operation" );
19591960
1960- bool isOmpDialect =
1961- op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1962- op->getDialect ();
1963-
1964- if (isOmpDialect)
1965- return op->hasTrait <OpTrait::IsTerminator>();
1966-
1967- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1968- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1969- memOp.getEffects (effects);
1970- return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1971- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1972- isa<SideEffects::AutomaticAllocationScopeResource>(
1973- effect.getResource ());
1974- });
1975- }
1976- return true ;
1977- }
1978-
1979- Operation *TargetOp::getInnermostCapturedOmpOp () {
1980- Dialect *ompDialect = (*this )->getDialect ();
1961+ Dialect *ompDialect = rootOp->getDialect ();
19811962 Operation *capturedOp = nullptr ;
19821963 DominanceInfo domInfo;
19831964
19841965 // Process in pre-order to check operations from outermost to innermost,
19851966 // ensuring we only enter the region of an operation if it meets the criteria
19861967 // for being captured. We stop the exploration of nested operations as soon as
19871968 // we process a region holding no operations to be captured.
1988- walk<WalkOrder::PreOrder>([&](Operation *op) {
1989- if (op == * this )
1969+ rootOp-> walk <WalkOrder::PreOrder>([&](Operation *op) {
1970+ if (op == rootOp )
19901971 return WalkResult::advance ();
19911972
19921973 // Ignore operations of other dialects or omp operations with no regions,
@@ -2001,22 +1982,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20011982 // (i.e. its block's successors can reach it) or if it's not guaranteed to
20021983 // be executed before all exits of the region (i.e. it doesn't dominate all
20031984 // blocks with no successors reachable from the entry block).
2004- Region *parentRegion = op->getParentRegion ();
2005- Block *parentBlock = op->getBlock ();
2006-
2007- for (Block *successor : parentBlock->getSuccessors ())
2008- if (successor->isReachable (parentBlock))
2009- return WalkResult::interrupt ();
2010-
2011- for (Block &block : *parentRegion)
2012- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2013- !domInfo.dominates (parentBlock, &block))
2014- return WalkResult::interrupt ();
1985+ if (checkSingleMandatoryExec) {
1986+ Region *parentRegion = op->getParentRegion ();
1987+ Block *parentBlock = op->getBlock ();
1988+
1989+ for (Block *successor : parentBlock->getSuccessors ())
1990+ if (successor->isReachable (parentBlock))
1991+ return WalkResult::interrupt ();
1992+
1993+ for (Block &block : *parentRegion)
1994+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1995+ !domInfo.dominates (parentBlock, &block))
1996+ return WalkResult::interrupt ();
1997+ }
20151998
20161999 // Don't capture this op if it has a not-allowed sibling, and stop recursing
20172000 // into nested operations.
20182001 for (Operation &sibling : op->getParentRegion ()->getOps ())
2019- if (&sibling != op && !siblingAllowedInCapture (&sibling))
2002+ if (&sibling != op && !siblingAllowedFn (&sibling))
20202003 return WalkResult::interrupt ();
20212004
20222005 // Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2012,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20292012 return capturedOp;
20302013}
20312014
2032- llvm::omp::OMPTgtExecModeFlags
2033- TargetOp::getKernelExecFlags (Operation *capturedOp) {
2034- using namespace llvm ::omp;
2015+ Operation *TargetOp::getInnermostCapturedOmpOp () {
2016+ auto *ompDialect = getContext ()->getLoadedDialect <omp::OpenMPDialect>(); // compared against each sibling's dialect below
2017+
2018+ // Sibling filter passed to findCapturedOmpOp: a null sibling is rejected, an OpenMP op is accepted only if it is a
2019+ // terminator, and a non-OpenMP op is accepted unless it reports a write to AutomaticAllocationScopeResource.
2020+ return findCapturedOmpOp (
2021+ *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2022+ if (!sibling)
2023+ return false ;
2024+
2025+ if (ompDialect == sibling->getDialect ())
2026+ return sibling->hasTrait <OpTrait::IsTerminator>();
2027+
2028+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2029+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2030+ effects;
2031+ memOp.getEffects (effects);
2032+ return !llvm::any_of (
2033+ effects, [&](MemoryEffects::EffectInstance &effect) {
2034+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2035+ isa<SideEffects::AutomaticAllocationScopeResource>(
2036+ effect.getResource ());
2037+ });
2038+ }
2039+ return true ; // ops that do not implement MemoryEffectOpInterface are accepted as siblings
2040+ });
2041+ }
20352042
2043+ TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
20362044 // A non-null captured op is only valid if it resides inside of a TargetOp
20372045 // and is the result of calling getInnermostCapturedOmpOp() on it.
20382046 TargetOp targetOp =
@@ -2041,60 +2049,94 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
20412049 (targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
20422050 " unexpected captured op" );
20432051
2044- // Make sure this region is capturing a loop. Otherwise, it's a generic
2045- // kernel.
2052+ // If it's not capturing a loop, it's a default target region.
20462053 if (!isa_and_present<LoopNestOp>(capturedOp))
2047- return OMP_TGT_EXEC_MODE_GENERIC ;
2054+ return TargetRegionFlags::generic ;
20482055
2049- SmallVector<LoopWrapperInterface> wrappers;
2050- cast<LoopNestOp>(capturedOp).gatherWrappers (wrappers);
2051- assert (!wrappers.empty ());
2056+ // Get the innermost non-simd loop wrapper.
2057+ SmallVector<LoopWrapperInterface> loopWrappers;
2058+ cast<LoopNestOp>(capturedOp).gatherWrappers (loopWrappers);
2059+ assert (!loopWrappers.empty ());
20522060
2053- // Ignore optional SIMD leaf construct.
2054- auto *innermostWrapper = wrappers.begin ();
2061+ LoopWrapperInterface *innermostWrapper = loopWrappers.begin ();
20552062 if (isa<SimdOp>(innermostWrapper))
20562063 innermostWrapper = std::next (innermostWrapper);
20572064
2058- long numWrappers = std::distance (innermostWrapper, wrappers.end ());
2059-
2060- // Detect Generic-SPMD: target-teams-distribute[-simd].
2061- // Detect SPMD: target-teams-loop.
2062- if (numWrappers == 1 ) {
2063- if (!isa<DistributeOp, LoopOp>(innermostWrapper))
2064- return OMP_TGT_EXEC_MODE_GENERIC;
2065-
2066- Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2067- if (!isa_and_present<TeamsOp>(teamsOp))
2068- return OMP_TGT_EXEC_MODE_GENERIC;
2065+ auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
2066+ if (numWrappers != 1 && numWrappers != 2 )
2067+ return TargetRegionFlags::generic;
20692068
2070- if (teamsOp->getParentOp () == targetOp.getOperation ())
2071- return isa<DistributeOp>(innermostWrapper)
2072- ? OMP_TGT_EXEC_MODE_GENERIC_SPMD
2073- : OMP_TGT_EXEC_MODE_SPMD;
2074- }
2075-
2076- // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
2069+ // Detect target-teams-distribute-parallel-wsloop[-simd].
20772070 if (numWrappers == 2 ) {
20782071 if (!isa<WsloopOp>(innermostWrapper))
2079- return OMP_TGT_EXEC_MODE_GENERIC ;
2072+ return TargetRegionFlags::generic ;
20802073
20812074 innermostWrapper = std::next (innermostWrapper);
20822075 if (!isa<DistributeOp>(innermostWrapper))
2083- return OMP_TGT_EXEC_MODE_GENERIC ;
2076+ return TargetRegionFlags::generic ;
20842077
20852078 Operation *parallelOp = (*innermostWrapper)->getParentOp ();
20862079 if (!isa_and_present<ParallelOp>(parallelOp))
2087- return OMP_TGT_EXEC_MODE_GENERIC ;
2080+ return TargetRegionFlags::generic ;
20882081
20892082 Operation *teamsOp = parallelOp->getParentOp ();
20902083 if (!isa_and_present<TeamsOp>(teamsOp))
2091- return OMP_TGT_EXEC_MODE_GENERIC ;
2084+ return TargetRegionFlags::generic ;
20922085
20932086 if (teamsOp->getParentOp () == targetOp.getOperation ())
2094- return OMP_TGT_EXEC_MODE_SPMD;
2087+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2088+ }
2089+ // Detect target-teams-distribute[-simd] and target-teams-loop.
2090+ else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2091+ Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2092+ if (!isa_and_present<TeamsOp>(teamsOp))
2093+ return TargetRegionFlags::generic;
2094+
2095+ if (teamsOp->getParentOp () != targetOp.getOperation ())
2096+ return TargetRegionFlags::generic;
2097+
2098+ if (isa<LoopOp>(innermostWrapper))
2099+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2100+
2101+ // Find single immediately nested captured omp.parallel and add spmd flag
2102+ // (generic-spmd case).
2103+ //
2104+ // TODO: This shouldn't have to be done here, as it is too easy to break.
2105+ // The openmp-opt pass should be updated to be able to promote kernels like
2106+ // this from "Generic" to "Generic-SPMD". However, the use of the
2107+ // `kmpc_distribute_static_loop` family of functions produced by the
2108+ // OMPIRBuilder for these kernels prevents that from working.
2109+ Dialect *ompDialect = targetOp->getDialect ();
2110+ Operation *nestedCapture = findCapturedOmpOp (
2111+ capturedOp, /* checkSingleMandatoryExec=*/ false ,
2112+ [&](Operation *sibling) {
2113+ return sibling && (ompDialect != sibling->getDialect () ||
2114+ sibling->hasTrait <OpTrait::IsTerminator>());
2115+ });
2116+
2117+ TargetRegionFlags result =
2118+ TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2119+
2120+ if (!nestedCapture)
2121+ return result;
2122+
2123+ while (nestedCapture->getParentOp () != capturedOp)
2124+ nestedCapture = nestedCapture->getParentOp ();
2125+
2126+ return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2127+ : result;
2128+ }
2129+ // Detect target-parallel-wsloop[-simd].
2130+ else if (isa<WsloopOp>(innermostWrapper)) {
2131+ Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2132+ if (!isa_and_present<ParallelOp>(parallelOp))
2133+ return TargetRegionFlags::generic;
2134+
2135+ if (parallelOp->getParentOp () == targetOp.getOperation ())
2136+ return TargetRegionFlags::spmd;
20952137 }
20962138
2097- return OMP_TGT_EXEC_MODE_GENERIC ;
2139+ return TargetRegionFlags::generic ;
20982140}
20992141
21002142// ===----------------------------------------------------------------------===//
0 commit comments