Skip to content

Commit ad88725

Browse files
committed
[MLIR][LLVMIR] Adding scan lowering to llvm on the mlir side
1 parent 5073733 commit ad88725

File tree

4 files changed

+503
-84
lines changed

4 files changed

+503
-84
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2326,12 +2326,41 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23262326

23272327
static mlir::omp::ScanOp
23282328
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2329-
semantics::SemanticsContext &semaCtx, mlir::Location loc,
2330-
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
2329+
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
2330+
mlir::Location loc, const ConstructQueue &queue,
2331+
ConstructQueue::const_iterator item) {
23312332
mlir::omp::ScanOperands clauseOps;
23322333
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
2333-
return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
2334-
converter.getCurrentLocation(), clauseOps);
2334+
mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
2335+
converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
2336+
// If there are nested loops all indices should be loaded after
2337+
// the scan construct as otherwise, it would result in using the index
2338+
// variable across scan directive.
2339+
// (`Intra-iteration dependences from a statement in the structured
2340+
// block sequence that precede a scan directive to a statement in the
2341+
// structured block sequence that follows a scan directive must not exist,
2342+
// except for dependences for the list items specified in an inclusive or
2343+
// exclusive clause.`).
2344+
// TODO: If there are nested loops, it is not handled.
2345+
mlir::omp::LoopNestOp loopNestOp =
2346+
scanOp->getParentOfType<mlir::omp::LoopNestOp>();
2347+
assert(loopNestOp.getNumLoops() == 1 &&
2348+
"Scan directive inside nested do loops is not handled yet.");
2349+
mlir::Region &region = loopNestOp->getRegion(0);
2350+
mlir::Value indexVal = fir::getBase(region.getArgument(0));
2351+
lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
2352+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2353+
lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
2354+
auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
2355+
assert(doStmt && "Expected do loop to be in the nested evaluation");
2356+
const auto &loopControl =
2357+
std::get<std::optional<parser::LoopControl>>(doStmt->t);
2358+
const parser::LoopControl::Bounds *bounds =
2359+
std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
2360+
mlir::Operation *storeOp =
2361+
setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
2362+
firOpBuilder.setInsertionPointAfter(storeOp);
2363+
return scanOp;
23352364
}
23362365

23372366
static mlir::omp::SectionsOp
@@ -3416,7 +3445,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
34163445
loc, queue, item);
34173446
break;
34183447
case llvm::omp::Directive::OMPD_scan:
3419-
newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
3448+
newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
34203449
break;
34213450
case llvm::omp::Directive::OMPD_section:
34223451
llvm_unreachable("genOMPDispatch: OMPD_section");

0 commit comments

Comments
 (0)