Skip to content

Commit 9fb0bde

Browse files
committed
Fix target teams distribute + parallel case
1 parent 9a18307 commit 9fb0bde

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,10 +2109,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21092109
if (isa<LoopOp>(innermostWrapper))
21102110
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
21112111

2112-
TargetRegionFlags result =
2113-
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2114-
2115-
// Find single nested parallel-do and add spmd flag (generic-spmd case).
2112+
// Find single immediately nested captured omp.parallel and add spmd flag
2113+
// (generic-spmd case).
2114+
//
21162115
// TODO: This shouldn't have to be done here, as it is too easy to break.
21172116
// The openmp-opt pass should be updated to be able to promote kernels like
21182117
// this from "Generic" to "Generic-SPMD". However, the use of the
@@ -2125,24 +2124,17 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21252124
sibling->hasTrait<OpTrait::IsTerminator>());
21262125
});
21272126

2128-
if (!isa_and_present<LoopNestOp>(nestedCapture))
2129-
return result;
2130-
2131-
int numNestedWrappers;
2132-
LoopWrapperInterface *nestedWrapper =
2133-
getInnermostWrapper(cast<LoopNestOp>(nestedCapture), numNestedWrappers);
2134-
2135-
if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
2136-
return result;
2127+
TargetRegionFlags result =
2128+
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
21372129

2138-
Operation *parallelOp = (*nestedWrapper)->getParentOp();
2139-
if (!isa_and_present<ParallelOp>(parallelOp))
2130+
if (!nestedCapture)
21402131
return result;
21412132

2142-
if (parallelOp->getParentOp() != capturedOp)
2143-
return result;
2133+
while (nestedCapture->getParentOp() != capturedOp)
2134+
nestedCapture = nestedCapture->getParentOp();
21442135

2145-
return result | TargetRegionFlags::spmd;
2136+
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2137+
: result;
21462138
}
21472139
// Detect target-parallel-wsloop[-simd].
21482140
else if (isa<WsloopOp>(innermostWrapper)) {

0 commit comments

Comments
 (0)