1111#include " flang/Optimizer/OpenMP/Utils.h"
1212#include " mlir/Analysis/SliceAnalysis.h"
1313#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
14+ #include " mlir/IR/IRMapping.h"
1415#include " mlir/Transforms/DialectConversion.h"
1516#include " mlir/Transforms/RegionUtils.h"
1617
@@ -24,7 +25,82 @@ namespace flangomp {
2425
2526namespace {
2627namespace looputils {
27- using LoopNest = llvm::SetVector<fir::DoLoopOp>;
28+ // / Stores info needed about the induction/iteration variable for each `do
29+ // / concurrent` in a loop nest. This includes only for now:
30+ // / * the operation allocating memory for iteration variable,
31+ struct InductionVariableInfo {
32+ mlir::Operation *iterVarMemDef;
33+ };
34+
35+ using LoopNestToIndVarMap =
36+ llvm::MapVector<fir::DoLoopOp, InductionVariableInfo>;
37+
38+ // / Given an operation `op`, this returns true if one of `op`'s operands is
39+ // / "ultimately" the loop's induction variable. This helps in cases where the
40+ // / induction variable's use is "hidden" behind a convert/cast.
41+ // /
42+ // / For example, give the following loop:
43+ // / ```
44+ // / fir.do_loop %ind_var = %lb to %ub step %s unordered {
45+ // / %ind_var_conv = fir.convert %ind_var : (index) -> i32
46+ // / fir.store %ind_var_conv to %i#1 : !fir.ref<i32>
47+ // / ...
48+ // / }
49+ // / ```
50+ // /
51+ // / If \p op is the `fir.store` operation, then this function will return true
52+ // / since the IV is the "ultimate" opeerand to the `fir.store` op through the
53+ // / `%ind_var_conv` -> `%ind_var` conversion sequence.
54+ // /
55+ // / For why this is useful, see its use in `findLoopIndVarMemDecl`.
56+ bool isIndVarUltimateOperand (mlir::Operation *op, fir::DoLoopOp doLoop) {
57+ while (op != nullptr && op->getNumOperands () > 0 ) {
58+ auto ivIt = llvm::find_if (op->getOperands (), [&](mlir::Value operand) {
59+ return operand == doLoop.getInductionVar ();
60+ });
61+
62+ if (ivIt != op->getOperands ().end ())
63+ return true ;
64+
65+ op = op->getOperand (0 ).getDefiningOp ();
66+ }
67+
68+ return false ;
69+ }
70+
71+ // / For the \p doLoop parameter, find the operation that declares its iteration
72+ // / variable or allocates memory for it.
73+ // /
74+ // / For example, give the following loop:
75+ // / ```
76+ // / ...
77+ // / %i:2 = hlfir.declare %0 {uniq_name = "_QFEi"} : ...
78+ // / ...
79+ // / fir.do_loop %ind_var = %lb to %ub step %s unordered {
80+ // / %ind_var_conv = fir.convert %ind_var : (index) -> i32
81+ // / fir.store %ind_var_conv to %i#1 : !fir.ref<i32>
82+ // / ...
83+ // / }
84+ // / ```
85+ // /
86+ // / This function returns the `hlfir.declare` op for `%i`.
87+ mlir::Operation *findLoopIterationVarMemDecl (fir::DoLoopOp doLoop) {
88+ mlir::Value result = nullptr ;
89+ mlir::visitUsedValuesDefinedAbove (
90+ doLoop.getRegion (), [&](mlir::OpOperand *operand) {
91+ if (result)
92+ return ;
93+
94+ if (isIndVarUltimateOperand (operand->getOwner (), doLoop)) {
95+ assert (result == nullptr &&
96+ " loop can have only one induction variable" );
97+ result = operand->get ();
98+ }
99+ });
100+
101+ assert (result != nullptr && result.getDefiningOp () != nullptr );
102+ return result.getDefiningOp ();
103+ }
28104
29105// / Loop \p innerLoop is considered perfectly-nested inside \p outerLoop iff
30106// / there are no operations in \p outerloop's body other than:
@@ -93,11 +169,14 @@ bool isPerfectlyNested(fir::DoLoopOp outerLoop, fir::DoLoopOp innerLoop) {
93169// / recognize a certain nested loop as part of the nest it just returns the
94170// / parent loops it discovered before.
95171mlir::LogicalResult collectLoopNest (fir::DoLoopOp currentLoop,
96- LoopNest &loopNest) {
172+ LoopNestToIndVarMap &loopNest) {
97173 assert (currentLoop.getUnordered ());
98174
99175 while (true ) {
100- loopNest.insert (currentLoop);
176+ loopNest.try_emplace (
177+ currentLoop,
178+ InductionVariableInfo{findLoopIterationVarMemDecl (currentLoop)});
179+
101180 auto directlyNestedLoops = currentLoop.getRegion ().getOps <fir::DoLoopOp>();
102181 llvm::SmallVector<fir::DoLoopOp> unorderedLoops;
103182
@@ -127,26 +206,136 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
127206public:
128207 using mlir::OpConversionPattern<fir::DoLoopOp>::OpConversionPattern;
129208
130- DoConcurrentConversion (mlir::MLIRContext *context, bool mapToDevice)
131- : OpConversionPattern(context), mapToDevice(mapToDevice) {}
209+ DoConcurrentConversion (mlir::MLIRContext *context, bool mapToDevice,
210+ llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip)
211+ : OpConversionPattern(context), mapToDevice(mapToDevice),
212+ concurrentLoopsToSkip (concurrentLoopsToSkip) {}
132213
133214 mlir::LogicalResult
134215 matchAndRewrite (fir::DoLoopOp doLoop, OpAdaptor adaptor,
135216 mlir::ConversionPatternRewriter &rewriter) const override {
136- looputils::LoopNest loopNest;
217+ looputils::LoopNestToIndVarMap loopNest;
137218 bool hasRemainingNestedLoops =
138219 failed (looputils::collectLoopNest (doLoop, loopNest));
139220 if (hasRemainingNestedLoops)
140221 mlir::emitWarning (doLoop.getLoc (),
141222 " Some `do concurent` loops are not perfectly-nested. "
142223 " These will be serialzied." );
143224
144- // TODO This will be filled in with the next PRs that upstreams the rest of
145- // the ROCm implementaion.
225+ mlir::IRMapping mapper;
226+ genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper);
227+ mlir::omp::LoopNestOperands loopNestClauseOps;
228+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loopNest, mapper,
229+ loopNestClauseOps);
230+
231+ mlir::omp::LoopNestOp ompLoopNest =
232+ genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps,
233+ /* isComposite=*/ mapToDevice);
234+
235+ rewriter.eraseOp (doLoop);
236+
237+ // Mark `unordered` loops that are not perfectly nested to be skipped from
238+ // the legality check of the `ConversionTarget` since we are not interested
239+ // in mapping them to OpenMP.
240+ ompLoopNest->walk ([&](fir::DoLoopOp doLoop) {
241+ if (doLoop.getUnordered ()) {
242+ concurrentLoopsToSkip.insert (doLoop);
243+ }
244+ });
245+
146246 return mlir::success ();
147247 }
148248
249+ private:
250+ mlir::omp::ParallelOp genParallelOp (mlir::Location loc,
251+ mlir::ConversionPatternRewriter &rewriter,
252+ looputils::LoopNestToIndVarMap &loopNest,
253+ mlir::IRMapping &mapper) const {
254+ auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(loc);
255+ rewriter.createBlock (¶llelOp.getRegion ());
256+ rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
257+
258+ genLoopNestIndVarAllocs (rewriter, loopNest, mapper);
259+ return parallelOp;
260+ }
261+
262+ void genLoopNestIndVarAllocs (mlir::ConversionPatternRewriter &rewriter,
263+ looputils::LoopNestToIndVarMap &loopNest,
264+ mlir::IRMapping &mapper) const {
265+
266+ for (auto &[_, indVarInfo] : loopNest)
267+ genInductionVariableAlloc (rewriter, indVarInfo.iterVarMemDef , mapper);
268+ }
269+
270+ mlir::Operation *
271+ genInductionVariableAlloc (mlir::ConversionPatternRewriter &rewriter,
272+ mlir::Operation *indVarMemDef,
273+ mlir::IRMapping &mapper) const {
274+ assert (
275+ indVarMemDef != nullptr &&
276+ " Induction variable memdef is expected to have a defining operation." );
277+
278+ llvm::SmallSetVector<mlir::Operation *, 2 > indVarDeclareAndAlloc;
279+ for (auto operand : indVarMemDef->getOperands ())
280+ indVarDeclareAndAlloc.insert (operand.getDefiningOp ());
281+ indVarDeclareAndAlloc.insert (indVarMemDef);
282+
283+ mlir::Operation *result;
284+ for (mlir::Operation *opToClone : indVarDeclareAndAlloc)
285+ result = rewriter.clone (*opToClone, mapper);
286+
287+ return result;
288+ }
289+
290+ void genLoopNestClauseOps (
291+ mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
292+ looputils::LoopNestToIndVarMap &loopNest, mlir::IRMapping &mapper,
293+ mlir::omp::LoopNestOperands &loopNestClauseOps) const {
294+ assert (loopNestClauseOps.loopLowerBounds .empty () &&
295+ " Loop nest bounds were already emitted!" );
296+
297+ auto populateBounds = [&](mlir::Value var,
298+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
299+ bounds.push_back (var.getDefiningOp ()->getResult (0 ));
300+ };
301+
302+ for (auto &[doLoop, _] : loopNest) {
303+ populateBounds (doLoop.getLowerBound (), loopNestClauseOps.loopLowerBounds );
304+ populateBounds (doLoop.getUpperBound (), loopNestClauseOps.loopUpperBounds );
305+ populateBounds (doLoop.getStep (), loopNestClauseOps.loopSteps );
306+ }
307+
308+ loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
309+ }
310+
311+ mlir::omp::LoopNestOp
312+ genWsLoopOp (mlir::ConversionPatternRewriter &rewriter, fir::DoLoopOp doLoop,
313+ mlir::IRMapping &mapper,
314+ const mlir::omp::LoopNestOperands &clauseOps,
315+ bool isComposite) const {
316+
317+ auto wsloopOp = rewriter.create <mlir::omp::WsloopOp>(doLoop.getLoc ());
318+ wsloopOp.setComposite (isComposite);
319+ rewriter.createBlock (&wsloopOp.getRegion ());
320+
321+ auto loopNestOp =
322+ rewriter.create <mlir::omp::LoopNestOp>(doLoop.getLoc (), clauseOps);
323+
324+ // Clone the loop's body inside the loop nest construct using the
325+ // mapped values.
326+ rewriter.cloneRegionBefore (doLoop.getRegion (), loopNestOp.getRegion (),
327+ loopNestOp.getRegion ().begin (), mapper);
328+
329+ mlir::Operation *terminator = loopNestOp.getRegion ().back ().getTerminator ();
330+ rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
331+ rewriter.create <mlir::omp::YieldOp>(terminator->getLoc ());
332+ rewriter.eraseOp (terminator);
333+
334+ return loopNestOp;
335+ }
336+
149337 bool mapToDevice;
338+ llvm::DenseSet<fir::DoLoopOp> &concurrentLoopsToSkip;
150339};
151340
152341class DoConcurrentConversionPass
@@ -175,16 +364,18 @@ class DoConcurrentConversionPass
175364 return ;
176365 }
177366
367+ llvm::DenseSet<fir::DoLoopOp> concurrentLoopsToSkip;
178368 mlir::RewritePatternSet patterns (context);
179369 patterns.insert <DoConcurrentConversion>(
180- context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device);
370+ context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
371+ concurrentLoopsToSkip);
181372 mlir::ConversionTarget target (*context);
182373 target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
183374 // The goal is to handle constructs that eventually get lowered to
184375 // `fir.do_loop` with the `unordered` attribute (e.g. array expressions).
185376 // Currently, this is only enabled for the `do concurrent` construct since
186377 // the pass runs early in the pipeline.
187- return !op.getUnordered ();
378+ return !op.getUnordered () || concurrentLoopsToSkip. contains (op) ;
188379 });
189380 target.markUnknownOpDynamicallyLegal (
190381 [](mlir::Operation *) { return true ; });
0 commit comments