1515#include " mlir/Transforms/DialectConversion.h"
1616
1717#include < memory>
18+ #include < optional>
19+ #include < type_traits>
1820
1921namespace flangomp {
2022#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
@@ -58,7 +60,7 @@ class GenericLoopConversionPattern
5860 if (teamsLoopCanBeParallelFor (loopOp))
5961 rewriteToDistributeParallelDo (loopOp, rewriter);
6062 else
61- rewriteToDistrbute (loopOp, rewriter);
63+ rewriteToDistribute (loopOp, rewriter);
6264 break ;
6365 }
6466
@@ -77,9 +79,6 @@ class GenericLoopConversionPattern
7779 if (loopOp.getOrder ())
7880 return todo (" order" );
7981
80- if (!loopOp.getReductionVars ().empty ())
81- return todo (" reduction" );
82-
8382 return mlir::success ();
8483 }
8584
@@ -168,7 +167,7 @@ class GenericLoopConversionPattern
168167 case ClauseBindKind::Parallel:
169168 return rewriteToWsloop (loopOp, rewriter);
170169 case ClauseBindKind::Teams:
171- return rewriteToDistrbute (loopOp, rewriter);
170+ return rewriteToDistribute (loopOp, rewriter);
172171 case ClauseBindKind::Thread:
173172 return rewriteToSimdLoop (loopOp, rewriter);
174173 }
@@ -211,8 +210,9 @@ class GenericLoopConversionPattern
211210 loopOp, rewriter);
212211 }
213212
214- void rewriteToDistrbute (mlir::omp::LoopOp loopOp,
215- mlir::ConversionPatternRewriter &rewriter) const {
213+ void rewriteToDistribute (mlir::omp::LoopOp loopOp,
214+ mlir::ConversionPatternRewriter &rewriter) const {
215+ assert (loopOp.getReductionVars ().empty ());
216216 rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
217217 mlir::omp::DistributeOperands>(loopOp, rewriter);
218218 }
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern
246246 Fortran::common::openmp::EntryBlockArgs args;
247247 args.priv .vars = clauseOps.privateVars ;
248248
249+ if constexpr (!std::is_same_v<OpOperandsTy,
250+ mlir::omp::DistributeOperands>) {
251+ populateReductionClauseOps (loopOp, clauseOps);
252+ args.reduction .vars = clauseOps.reductionVars ;
253+ }
254+
249255 auto wrapperOp = rewriter.create <OpTy>(loopOp.getLoc (), clauseOps);
250256 mlir::Block *opBlock = genEntryBlock (rewriter, args, wrapperOp.getRegion ());
251257
@@ -275,8 +281,7 @@ class GenericLoopConversionPattern
275281
276282 auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(loopOp.getLoc (),
277283 parallelClauseOps);
278- mlir::Block *parallelBlock =
279- genEntryBlock (rewriter, parallelArgs, parallelOp.getRegion ());
284+ genEntryBlock (rewriter, parallelArgs, parallelOp.getRegion ());
280285 parallelOp.setComposite (true );
281286 rewriter.setInsertionPoint (
282287 rewriter.create <mlir::omp::TerminatorOp>(loopOp.getLoc ()));
@@ -288,20 +293,54 @@ class GenericLoopConversionPattern
288293 rewriter.createBlock (&distributeOp.getRegion ());
289294
290295 mlir::omp::WsloopOperands wsloopClauseOps;
296+ populateReductionClauseOps (loopOp, wsloopClauseOps);
297+ Fortran::common::openmp::EntryBlockArgs wsloopArgs;
298+ wsloopArgs.reduction .vars = wsloopClauseOps.reductionVars ;
299+
291300 auto wsloopOp =
292301 rewriter.create <mlir::omp::WsloopOp>(loopOp.getLoc (), wsloopClauseOps);
293302 wsloopOp.setComposite (true );
294- rewriter. createBlock (& wsloopOp.getRegion ());
303+ genEntryBlock (rewriter, wsloopArgs, wsloopOp.getRegion ());
295304
296305 mlir::IRMapping mapper;
297- mlir::Block &loopBlock = *loopOp.getRegion ().begin ();
298306
299- for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal (
300- loopBlock.getArguments (), parallelBlock->getArguments ()))
307+ auto loopBlockInterface =
308+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
309+ auto parallelBlockInterface =
310+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
311+ auto wsloopBlockInterface =
312+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);
313+
314+ for (auto [loopOpArg, parallelOpArg] :
315+ llvm::zip_equal (loopBlockInterface.getPrivateBlockArgs (),
316+ parallelBlockInterface.getPrivateBlockArgs ()))
301317 mapper.map (loopOpArg, parallelOpArg);
302318
319+ for (auto [loopOpArg, wsloopOpArg] :
320+ llvm::zip_equal (loopBlockInterface.getReductionBlockArgs (),
321+ wsloopBlockInterface.getReductionBlockArgs ()))
322+ mapper.map (loopOpArg, wsloopOpArg);
323+
303324 rewriter.clone (*loopOp.begin (), mapper);
304325 }
326+
327+ void
328+ populateReductionClauseOps (mlir::omp::LoopOp loopOp,
329+ mlir::omp::ReductionClauseOps &clauseOps) const {
330+ clauseOps.reductionMod = loopOp.getReductionModAttr ();
331+ clauseOps.reductionVars = loopOp.getReductionVars ();
332+
333+ std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms ();
334+ if (reductionSyms)
335+ clauseOps.reductionSyms .assign (reductionSyms->begin (),
336+ reductionSyms->end ());
337+
338+ std::optional<llvm::ArrayRef<bool >> reductionByref =
339+ loopOp.getReductionByref ();
340+ if (reductionByref)
341+ clauseOps.reductionByref .assign (reductionByref->begin (),
342+ reductionByref->end ());
343+ }
305344};
306345
307346class GenericLoopConversionPass
0 commit comments