|
15 | 15 | #include "mlir/Transforms/DialectConversion.h" |
16 | 16 |
|
17 | 17 | #include <memory> |
| 18 | +#include <optional> |
| 19 | +#include <type_traits> |
18 | 20 |
|
19 | 21 | namespace flangomp { |
20 | 22 | #define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS |
@@ -77,9 +79,6 @@ class GenericLoopConversionPattern |
77 | 79 | if (loopOp.getOrder()) |
78 | 80 | return todo("order"); |
79 | 81 |
|
80 | | - if (!loopOp.getReductionVars().empty()) |
81 | | - return todo("reduction"); |
82 | | - |
83 | 82 | return mlir::success(); |
84 | 83 | } |
85 | 84 |
|
@@ -213,6 +212,7 @@ class GenericLoopConversionPattern |
213 | 212 |
|
214 | 213 | void rewriteToDistrbute(mlir::omp::LoopOp loopOp, |
215 | 214 | mlir::ConversionPatternRewriter &rewriter) const { |
| 215 | + assert(loopOp.getReductionVars().empty()); |
216 | 216 | rewriteToSingleWrapperOp<mlir::omp::DistributeOp, |
217 | 217 | mlir::omp::DistributeOperands>(loopOp, rewriter); |
218 | 218 | } |
@@ -246,6 +246,12 @@ class GenericLoopConversionPattern |
246 | 246 | Fortran::common::openmp::EntryBlockArgs args; |
247 | 247 | args.priv.vars = clauseOps.privateVars; |
248 | 248 |
|
| 249 | + if constexpr (!std::is_same_v<OpOperandsTy, |
| 250 | + mlir::omp::DistributeOperands>) { |
| 251 | + populateReductionClauseOps(loopOp, clauseOps); |
| 252 | + args.reduction.vars = clauseOps.reductionVars; |
| 253 | + } |
| 254 | + |
249 | 255 | auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps); |
250 | 256 | mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion()); |
251 | 257 |
|
@@ -288,20 +294,51 @@ class GenericLoopConversionPattern |
288 | 294 | rewriter.createBlock(&distributeOp.getRegion()); |
289 | 295 |
|
290 | 296 | mlir::omp::WsloopOperands wsloopClauseOps; |
| 297 | + populateReductionClauseOps(loopOp, wsloopClauseOps); |
| 298 | + Fortran::common::openmp::EntryBlockArgs wsloopArgs; |
| 299 | + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; |
| 300 | + |
291 | 301 | auto wsloopOp = |
292 | 302 | rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps); |
293 | 303 | wsloopOp.setComposite(true); |
294 | | - rewriter.createBlock(&wsloopOp.getRegion()); |
| 304 | + mlir::Block *loopBlock = |
| 305 | + genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion()); |
295 | 306 |
|
296 | 307 | mlir::IRMapping mapper; |
297 | | - mlir::Block &loopBlock = *loopOp.getRegion().begin(); |
298 | 308 |
|
299 | | - for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal( |
300 | | - loopBlock.getArguments(), parallelBlock->getArguments())) |
| 309 | + auto loopBlockInterface = |
| 310 | + llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp); |
| 311 | + |
| 312 | + for (auto [loopOpArg, parallelOpArg] : |
| 313 | + llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(), |
| 314 | + parallelBlock->getArguments())) |
301 | 315 | mapper.map(loopOpArg, parallelOpArg); |
302 | 316 |
|
| 317 | + for (auto [loopOpArg, wsloopOpArg] : |
| 318 | + llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(), |
| 319 | + loopBlock->getArguments())) |
| 320 | + mapper.map(loopOpArg, wsloopOpArg); |
| 321 | + |
303 | 322 | rewriter.clone(*loopOp.begin(), mapper); |
304 | 323 | } |
| 324 | + |
| 325 | + template <typename OpOperandsTy> |
| 326 | + void populateReductionClauseOps(mlir::omp::LoopOp loopOp, |
| 327 | + OpOperandsTy &clauseOps) const { |
| 328 | + clauseOps.reductionMod = loopOp.getReductionModAttr(); |
| 329 | + clauseOps.reductionVars = loopOp.getReductionVars(); |
| 330 | + |
| 331 | + std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms(); |
| 332 | + if (reductionSyms) |
| 333 | + clauseOps.reductionSyms.assign(reductionSyms->begin(), |
| 334 | + reductionSyms->end()); |
| 335 | + |
| 336 | + std::optional<llvm::ArrayRef<bool>> reductionByref = |
| 337 | + loopOp.getReductionByref(); |
| 338 | + if (reductionByref) |
| 339 | + clauseOps.reductionByref.assign(reductionByref->begin(), |
| 340 | + reductionByref->end()); |
| 341 | + } |
305 | 342 | }; |
306 | 343 |
|
307 | 344 | class GenericLoopConversionPass |
|
0 commit comments