Commit bff6b92

[flang][OpenMP] Map teams loop to teams distribute when required. (#127489)
This extends support for generic `loop` rewriting by:

1. Preventing multiple worksharing loops from being nested inside each other. This is checked by walking the `teams loop` region and searching for any nested `loop` directive whose `bind` modifier is `parallel`.
2. Preventing conversion to a worksharing loop when calls to unknown functions are found in the `loop` directive's body, since those functions might contain nested parallelism of their own.

We walk the `teams loop` body looking for either of these two conditions; if either one holds, we map the `loop` directive to `distribute` instead of `distribute parallel do`.
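As a rough illustration of the mapping decision (hypothetical subroutine names, not part of this commit), the following Fortran sketch shows the two cases: the first `teams loop` nests a `loop bind(parallel)` and is therefore mapped to `teams distribute` only, while the second has a simple body and can be mapped to `teams distribute parallel do`.

! Hypothetical example (not from the commit): the nested `loop bind(parallel)`
! is itself a worksharing loop, so the enclosing `teams loop` is mapped to
! `teams distribute` only.
subroutine maps_to_distribute(val)
  implicit none
  integer :: val(20), i, j
  !$omp target teams loop map(tofrom: val)
  do i = 1, 5
    !$omp loop bind(parallel)
    do j = 1, 5
      val(i + j) = i + j
    end do
  end do
end subroutine

! Hypothetical example (not from the commit): no nested worksharing loop and no
! calls to unknown functions, so the `teams loop` can be mapped to
! `teams distribute parallel do`.
subroutine maps_to_distribute_parallel_do(val)
  implicit none
  integer :: val(20), i
  !$omp target teams loop map(tofrom: val)
  do i = 1, 5
    val(i) = i
  end do
end subroutine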
1 parent 4d6167e commit bff6b92

File tree: 2 files changed (+144, −5 lines)

flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp

Lines changed: 60 additions & 3 deletions
@@ -56,7 +56,10 @@ class GenericLoopConversionPattern
           "not yet implemented: Combined `parallel loop` directive");
       break;
     case GenericLoopCombinedInfo::TeamsLoop:
-      rewriteToDistributeParallelDo(loopOp, rewriter);
+      if (teamsLoopCanBeParallelFor(loopOp))
+        rewriteToDistributeParallelDo(loopOp, rewriter);
+      else
+        rewriteToDistrbute(loopOp, rewriter);
       break;
     }
 

@@ -97,8 +100,6 @@ class GenericLoopConversionPattern
     if (!loopOp.getReductionVars().empty())
       return todo("reduction");
 
-    // TODO For `teams loop`, check similar constrains to what is checked
-    // by `TeamsLoopChecker` in SemaOpenMP.cpp.
     return mlir::success();
   }
 

@@ -118,6 +119,62 @@ class GenericLoopConversionPattern
     return result;
   }
 
+  /// Checks whether a `teams loop` construct can be rewritten to `teams
+  /// distribute parallel do` or it has to be converted to `teams distribute`.
+  ///
+  /// This checks similar constraints to what is checked by `TeamsLoopChecker` in
+  /// SemaOpenMP.cpp in clang.
+  static bool teamsLoopCanBeParallelFor(mlir::omp::LoopOp loopOp) {
+    bool canBeParallelFor =
+        !loopOp
+             .walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
+               if (nestedOp == loopOp)
+                 return mlir::WalkResult::advance();
+
+               if (auto nestedLoopOp =
+                       mlir::dyn_cast<mlir::omp::LoopOp>(nestedOp)) {
+                 GenericLoopCombinedInfo combinedInfo =
+                     findGenericLoopCombineInfo(nestedLoopOp);
+
+                 // Worksharing loops cannot be nested inside each other.
+                 // Therefore, if the current `loop` directive nests another
+                 // `loop` whose `bind` modifier is `parallel`, this `loop`
+                 // directive cannot be mapped to `distribute parallel for`
+                 // but rather only to `distribute`.
+                 if (combinedInfo == GenericLoopCombinedInfo::Standalone &&
+                     nestedLoopOp.getBindKind() &&
+                     *nestedLoopOp.getBindKind() ==
+                         mlir::omp::ClauseBindKind::Parallel)
+                   return mlir::WalkResult::interrupt();
+
+                 // TODO check for combined `parallel loop` when we support
+                 // it.
+               } else if (auto callOp =
+                              mlir::dyn_cast<mlir::CallOpInterface>(nestedOp)) {
+                 // Calls to non-OpenMP API runtime functions inhibit
+                 // transformation to `teams distribute parallel do` since the
+                 // called functions might have nested parallelism themselves.
+                 bool isOpenMPAPI = false;
+                 mlir::CallInterfaceCallable callable =
+                     callOp.getCallableForCallee();
+
+                 if (auto callableSymRef =
+                         mlir::dyn_cast<mlir::SymbolRefAttr>(callable))
+                   isOpenMPAPI =
+                       callableSymRef.getRootReference().strref().starts_with(
+                           "omp_");
+
+                 if (!isOpenMPAPI)
+                   return mlir::WalkResult::interrupt();
+               }
+
+               return mlir::WalkResult::advance();
+             })
+             .wasInterrupted();
+
+    return canBeParallelFor;
+  }
+
   void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp,
                              mlir::ConversionPatternRewriter &rewriter) const {
     using namespace mlir::omp;

flang/test/Lower/OpenMP/loop-directive.f90

Lines changed: 84 additions & 2 deletions
@@ -1,7 +1,8 @@
 ! This test checks lowering of OpenMP loop Directive.
 
-! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
-! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
+! REQUIRES: openmp_runtime
+
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
 
 ! CHECK: omp.declare_reduction @[[RED:add_reduction_i32]] : i32
 ! CHECK: omp.private {type = private} @[[DUMMY_PRIV:.*test_privateEdummy_private.*]] : i32

@@ -179,3 +180,84 @@ subroutine test_standalone_bind_parallel
     c(i) = a(i) * b(i)
   end do
 end subroutine
+
+! CHECK-LABEL: func.func @_QPteams_loop_cannot_be_parallel_for
+subroutine teams_loop_cannot_be_parallel_for
+  implicit none
+  integer :: iter, iter2, val(20)
+  val = 0
+  ! CHECK: omp.teams {
+
+  ! Verify the outer `loop` directive was mapped to only `distribute`.
+  ! CHECK-NOT: omp.parallel {{.*}}
+  ! CHECK: omp.distribute {{.*}} {
+  ! CHECK-NEXT: omp.loop_nest {{.*}} {
+
+  ! Verify the inner `loop` directive was mapped to a worksharing loop.
+  ! CHECK: omp.wsloop {{.*}} {
+  ! CHECK-NEXT: omp.loop_nest {{.*}} {
+  ! CHECK: }
+  ! CHECK: }
+
+  ! CHECK: }
+  ! CHECK: }
+
+  ! CHECK: }
+  !$omp target teams loop map(tofrom:val)
+  DO iter = 1, 5
+    !$omp loop bind(parallel)
+    DO iter2 = 1, 5
+      val(iter+iter2) = iter+iter2
+    END DO
+  END DO
+end subroutine
+
+subroutine foo()
+end subroutine
+
+! CHECK-LABEL: func.func @_QPteams_loop_cannot_be_parallel_for_2
+subroutine teams_loop_cannot_be_parallel_for_2
+  implicit none
+  integer :: iter, val(20)
+  val = 0
+
+  ! CHECK: omp.teams {
+
+  ! Verify the `loop` directive was mapped to only `distribute`.
+  ! CHECK-NOT: omp.parallel {{.*}}
+  ! CHECK: omp.distribute {{.*}} {
+  ! CHECK-NEXT: omp.loop_nest {{.*}} {
+  ! CHECK: fir.call @_QPfoo
+  ! CHECK: }
+  ! CHECK: }
+
+  ! CHECK: }
+  !$omp target teams loop map(tofrom:val)
+  DO iter = 1, 5
+    call foo()
+  END DO
+end subroutine
+
+! CHECK-LABEL: func.func @_QPteams_loop_can_be_parallel_for
+subroutine teams_loop_can_be_parallel_for
+  use omp_lib
+  implicit none
+  integer :: iter, tid, val(20)
+  val = 0
+
+  !CHECK: omp.teams {
+  !CHECK: omp.parallel {{.*}} {
+  !CHECK: omp.distribute {
+  !CHECK: omp.wsloop {
+  !CHECK: omp.loop_nest {{.*}} {
+  !CHECK: fir.call @omp_get_thread_num()
+  !CHECK: }
+  !CHECK: }
+  !CHECK: }
+  !CHECK: }
+  !CHECK: }
+  !$omp target teams loop map(tofrom:val)
+  DO iter = 1, 5
+    tid = omp_get_thread_num()
+  END DO
+end subroutine
