Skip to content

Commit 27a9e19

Browse files
committed
MLIR][LLVMIR] Adding scan lowering to llvm on the mlir side
1 parent 5073733 commit 27a9e19

File tree

5 files changed

+526
-84
lines changed

5 files changed

+526
-84
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,12 +2326,40 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23262326

23272327
static mlir::omp::ScanOp
23282328
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2329-
semantics::SemanticsContext &semaCtx, mlir::Location loc,
2330-
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
2329+
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
2330+
mlir::Location loc, const ConstructQueue &queue,
2331+
ConstructQueue::const_iterator item) {
23312332
mlir::omp::ScanOperands clauseOps;
23322333
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
2333-
return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
2334-
converter.getCurrentLocation(), clauseOps);
2334+
mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
2335+
converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
2336+
// All loop indices should be loaded after the scan construct as otherwise,
2337+
// it would result in using the index variable across scan directive.
2338+
// (`Intra-iteration dependences from a statement in the structured
2339+
// block sequence that precede a scan directive to a statement in the
2340+
// structured block sequence that follows a scan directive must not exist,
2341+
// except for dependences for the list items specified in an inclusive or
2342+
// exclusive clause.`).
2343+
// TODO: Nested loops are not handled.
2344+
mlir::omp::LoopNestOp loopNestOp =
2345+
scanOp->getParentOfType<mlir::omp::LoopNestOp>();
2346+
assert(loopNestOp.getNumLoops() == 1 &&
2347+
"Scan directive inside nested do loops is not handled yet.");
2348+
mlir::Region &region = loopNestOp->getRegion(0);
2349+
mlir::Value indexVal = fir::getBase(region.getArgument(0));
2350+
lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
2351+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2352+
lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
2353+
auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
2354+
assert(doStmt && "Expected do loop to be in the nested evaluation");
2355+
const auto &loopControl =
2356+
std::get<std::optional<parser::LoopControl>>(doStmt->t);
2357+
const parser::LoopControl::Bounds *bounds =
2358+
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
2359+
mlir::Operation *storeOp =
2360+
setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
2361+
firOpBuilder.setInsertionPointAfter(storeOp);
2362+
return scanOp;
23352363
}
23362364

23372365
static mlir::omp::SectionsOp
@@ -3416,7 +3444,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
34163444
loc, queue, item);
34173445
break;
34183446
case llvm::omp::Directive::OMPD_scan:
3419-
newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
3447+
newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
34203448
break;
34213449
case llvm::omp::Directive::OMPD_section:
34223450
llvm_unreachable("genOMPDispatch: OMPD_section");

flang/test/Examples/omp-scan.f90

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
! RUN: %flang_fc1 -fopenmp -emit-obj %s -o %t.o
2+
! RUN: %flang -fopenmp -o %t %t.o
3+
! RUN: %t | FileCheck %s
4+
program inclusive_scan
5+
implicit none
6+
integer, parameter :: n = 100
7+
integer a(n), b(n)
8+
integer x, k, y, z
9+
10+
! initialization
11+
x = 0
12+
do k = 1, n
13+
a(k) = k
14+
end do
15+
16+
! a(k) is included in the computation of producing results in b(k)
17+
!$omp parallel do reduction(inscan, +: x)
18+
do k = 1, n
19+
x = x + a(k)
20+
!$omp scan inclusive(x)
21+
b(k) = x
22+
end do
23+
24+
print *,'x =', x
25+
end program
26+
!CHECK: x=5050

0 commit comments

Comments
 (0)