Skip to content

Commit 2decfef

Browse files
committed
MLIR][LLVMIR] Adding scan lowering to llvm on the mlir side
1 parent 5073733 commit 2decfef

File tree

6 files changed

+606
-84
lines changed

6 files changed

+606
-84
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,12 +2326,52 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23262326

23272327
static mlir::omp::ScanOp
23282328
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2329-
semantics::SemanticsContext &semaCtx, mlir::Location loc,
2330-
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
2329+
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
2330+
mlir::Location loc, const ConstructQueue &queue,
2331+
ConstructQueue::const_iterator item) {
23312332
mlir::omp::ScanOperands clauseOps;
23322333
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
2333-
return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
2334-
converter.getCurrentLocation(), clauseOps);
2334+
mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
2335+
converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
2336+
2337+
/// Scan redution is not implemented with nested workshare loops, linear
2338+
/// clause, tiling
2339+
mlir::omp::LoopNestOp loopNestOp =
2340+
scanOp->getParentOfType<mlir::omp::LoopNestOp>();
2341+
mlir::omp::WsloopOp wsLoopOp = scanOp->getParentOfType<mlir::omp::WsloopOp>();
2342+
bool isNested =
2343+
(loopNestOp.getNumLoops() > 1) ||
2344+
(wsLoopOp && (wsLoopOp->getParentOfType<mlir::omp::WsloopOp>()));
2345+
if (isNested)
2346+
TODO(loc, "Scan directive inside nested workshare loops");
2347+
if (wsLoopOp && !wsLoopOp.getLinearVars().empty())
2348+
TODO(loc, "Scan directive with linear clause");
2349+
if (loopNestOp.getTileSizes())
2350+
TODO(loc, "Scan directive with loop tiling");
2351+
2352+
// All loop indices should be loaded after the scan construct as otherwise,
2353+
// it would result in using the index variable across scan directive.
2354+
// (`Intra-iteration dependences from a statement in the structured
2355+
// block sequence that precede a scan directive to a statement in the
2356+
// structured block sequence that follows a scan directive must not exist,
2357+
// except for dependences for the list items specified in an inclusive or
2358+
// exclusive clause.`).
2359+
// TODO: Nested loops are not handled.
2360+
mlir::Region &region = loopNestOp->getRegion(0);
2361+
mlir::Value indexVal = fir::getBase(region.getArgument(0));
2362+
lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
2363+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2364+
lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
2365+
auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
2366+
assert(doStmt && "Expected do loop to be in the nested evaluation");
2367+
const auto &loopControl =
2368+
std::get<std::optional<parser::LoopControl>>(doStmt->t);
2369+
const parser::LoopControl::Bounds *bounds =
2370+
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
2371+
mlir::Operation *storeOp =
2372+
setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
2373+
firOpBuilder.setInsertionPointAfter(storeOp);
2374+
return scanOp;
23352375
}
23362376

23372377
static mlir::omp::SectionsOp
@@ -3416,7 +3456,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
34163456
loc, queue, item);
34173457
break;
34183458
case llvm::omp::Directive::OMPD_scan:
3419-
newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
3459+
newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
34203460
break;
34213461
case llvm::omp::Directive::OMPD_section:
34223462
llvm_unreachable("genOMPDispatch: OMPD_section");
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
! Tests scan reduction behavior when used in nested workshare loops
2+
3+
! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
4+
5+
program nested_scan_example
6+
implicit none
7+
integer, parameter :: n = 4, m = 5
8+
integer :: a(n, m), b(n, m)
9+
integer :: i, j
10+
integer :: row_sum, col_sum
11+
12+
do i = 1, n
13+
do j = 1, m
14+
a(i, j) = i + j
15+
end do
16+
end do
17+
18+
!$omp parallel do reduction(inscan, +: row_sum) private(col_sum, j)
19+
do i = 1, n
20+
row_sum = row_sum + i
21+
!$omp scan inclusive(row_sum)
22+
23+
col_sum = 0
24+
!$omp parallel do reduction(inscan, +: col_sum)
25+
do j = 1, m
26+
col_sum = col_sum + a(i, j)
27+
!CHECK: not yet implemented: Scan directive inside nested workshare loops
28+
!$omp scan inclusive(col_sum)
29+
b(i, j) = col_sum + row_sum
30+
end do
31+
!$omp end parallel do
32+
end do
33+
!$omp end parallel do
34+
end program nested_scan_example
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
! Tests scan reduction behavior when used in nested workshare loops
2+
3+
! RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
4+
5+
program nested_loop_example
6+
implicit none
7+
integer :: i, j, x
8+
integer, parameter :: N = 100, M = 200
9+
real :: A(N, M), B(N, M)
10+
x = 0
11+
12+
do i = 1, N
13+
do j = 1, M
14+
A(i, j) = i * j
15+
end do
16+
end do
17+
18+
!$omp parallel do collapse(2) reduction(inscan, +:x)
19+
do i = 1, N
20+
do j = 1, M
21+
x = x + A(i,j)
22+
!CHECK: not yet implemented: Scan directive inside nested workshare loops
23+
!$omp scan inclusive(x)
24+
B(i, j) = x
25+
end do
26+
end do
27+
!$omp end parallel do
28+
29+
end program nested_loop_example

0 commit comments

Comments
 (0)