Skip to content

Commit 011e73b

Browse files
authored
[MLIR][OpenMP] Relax requirement to produce Generic-SPMD kernels (llvm#1582)
2 parents f28cf86 + 9fac467 commit 011e73b

File tree

1 file changed

+37
-51
lines changed

1 file changed

+37
-51
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,7 +1970,7 @@ LogicalResult TargetOp::verifyRegions() {
19701970
}
19711971

19721972
static Operation *
1973-
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
1973+
findCapturedOmpOp(Operation *rootOp,
19741974
llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
19751975
assert(rootOp && "expected valid operation");
19761976

@@ -1998,19 +1998,17 @@ findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
19981998
// (i.e. its block's successors can reach it) or if it's not guaranteed to
19991999
// be executed before all exits of the region (i.e. it doesn't dominate all
20002000
// blocks with no successors reachable from the entry block).
2001-
if (checkSingleMandatoryExec) {
2002-
Region *parentRegion = op->getParentRegion();
2003-
Block *parentBlock = op->getBlock();
2004-
2005-
for (Block *successor : parentBlock->getSuccessors())
2006-
if (successor->isReachable(parentBlock))
2007-
return WalkResult::interrupt();
2008-
2009-
for (Block &block : *parentRegion)
2010-
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2011-
!domInfo.dominates(parentBlock, &block))
2012-
return WalkResult::interrupt();
2013-
}
2001+
Region *parentRegion = op->getParentRegion();
2002+
Block *parentBlock = op->getBlock();
2003+
2004+
for (Block *successor : parentBlock->getSuccessors())
2005+
if (successor->isReachable(parentBlock))
2006+
return WalkResult::interrupt();
2007+
2008+
for (Block &block : *parentRegion)
2009+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2010+
!domInfo.dominates(parentBlock, &block))
2011+
return WalkResult::interrupt();
20142012

20152013
// Don't capture this op if it has a not-allowed sibling, and stop recursing
20162014
// into nested operations.
@@ -2033,27 +2031,25 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20332031

20342032
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
20352033
// effects, but don't include a memory write effect.
2036-
return findCapturedOmpOp(
2037-
*this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2038-
if (!sibling)
2039-
return false;
2040-
2041-
if (ompDialect == sibling->getDialect())
2042-
return sibling->hasTrait<OpTrait::IsTerminator>();
2043-
2044-
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2045-
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2046-
effects;
2047-
memOp.getEffects(effects);
2048-
return !llvm::any_of(
2049-
effects, [&](MemoryEffects::EffectInstance &effect) {
2050-
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2051-
isa<SideEffects::AutomaticAllocationScopeResource>(
2052-
effect.getResource());
2053-
});
2054-
}
2055-
return true;
2034+
return findCapturedOmpOp(*this, [&](Operation *sibling) {
2035+
if (!sibling)
2036+
return false;
2037+
2038+
if (ompDialect == sibling->getDialect())
2039+
return sibling->hasTrait<OpTrait::IsTerminator>();
2040+
2041+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2042+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2043+
effects;
2044+
memOp.getEffects(effects);
2045+
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
2046+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2047+
isa<SideEffects::AutomaticAllocationScopeResource>(
2048+
effect.getResource());
20562049
});
2050+
}
2051+
return true;
2052+
});
20572053
}
20582054

20592055
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
@@ -2114,33 +2110,23 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21142110
if (isa<LoopOp>(innermostWrapper))
21152111
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
21162112

2117-
// Find single immediately nested captured omp.parallel and add spmd flag
2118-
// (generic-spmd case).
2113+
// Add spmd flag if there's a nested omp.parallel (generic-spmd case).
21192114
//
21202115
// TODO: This shouldn't have to be done here, as it is too easy to break.
21212116
// The openmp-opt pass should be updated to be able to promote kernels like
21222117
// this from "Generic" to "Generic-SPMD". However, the use of the
21232118
// `kmpc_distribute_static_loop` family of functions produced by the
21242119
// OMPIRBuilder for these kernels prevents that from working.
2125-
Dialect *ompDialect = targetOp->getDialect();
2126-
Operation *nestedCapture = findCapturedOmpOp(
2127-
capturedOp, /*checkSingleMandatoryExec=*/false,
2128-
[&](Operation *sibling) {
2129-
return sibling && (ompDialect != sibling->getDialect() ||
2130-
sibling->hasTrait<OpTrait::IsTerminator>());
2131-
});
2120+
bool hasParallel = capturedOp
2121+
->walk<WalkOrder::PreOrder>([](ParallelOp) {
2122+
return WalkResult::interrupt();
2123+
})
2124+
.wasInterrupted();
21322125

21332126
TargetRegionFlags result =
21342127
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
21352128

2136-
if (!nestedCapture)
2137-
return result;
2138-
2139-
while (nestedCapture->getParentOp() != capturedOp)
2140-
nestedCapture = nestedCapture->getParentOp();
2141-
2142-
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2143-
: result;
2129+
return hasParallel ? result | TargetRegionFlags::spmd : result;
21442130
}
21452131
// Detect target-parallel-wsloop[-simd].
21462132
else if (isa<WsloopOp>(innermostWrapper)) {

0 commit comments

Comments
 (0)