@@ -343,6 +343,115 @@ class GenericLoopConversionPattern
343343 }
344344};
345345
346+ // / According to the spec (v5.2, p340, 36):
347+ // /
348+ // / ```
349+ // / The effect of the reduction clause is as if it is applied to all leaf
350+ // / constructs that permit the clause, except for the following constructs:
351+ // / * ....
352+ // / * The teams construct, when combined with the loop construct.
353+ // / ```
354+ // /
355+ // / Therefore, for a combined directive similar to: `!$omp teams loop
356+ // / reduction(...)`, the earlier stages of the compiler assign the `reduction`
357+ // / clauses only to the `loop` leaf and not to the `teams` leaf.
358+ // /
359+ // / On the other hand, if we have a combined construct similar to: `!$omp teams
360+ // / distribute parallel do`, the `reduction` clauses are assigned both to the
361+ // / `teams` and the `do` leaves. We need to match this behavior when we convert
362+ // / `teams` op with a nested `loop` op since the target set of constructs/ops
363+ // / will be incorrect without moving the reductions up to the `teams` op as
364+ // / well.
365+ // /
366+ // / This pattern does exactly this. Given the following input:
367+ // / ```
368+ // / omp.teams {
369+ // / omp.loop reduction(@red_sym %red_op -> %red_arg : !fir.ref<i32>) {
370+ // / omp.loop_nest ... {
371+ // / ...
372+ // / }
373+ // / }
374+ // / }
375+ // / ```
376+ // / this pattern updates the `omp.teams` op in-place to:
377+ // / ```
378+ // / omp.teams reduction(@red_sym %red_op -> %teams_red_arg : !fir.ref<i32>) {
379+ // / omp.loop reduction(@red_sym %teams_red_arg -> %red_arg : !fir.ref<i32>) {
380+ // / omp.loop_nest ... {
381+ // / ...
382+ // / }
383+ // / }
384+ // / }
385+ // / ```
386+ // /
387+ // / Note the following:
388+ // / * The nested `omp.loop` is not rewritten by this pattern, this happens
389+ // / through `GenericLoopConversionPattern`.
390+ // / * The reduction info are cloned from the nested `omp.loop` op to the parent
391+ // / `omp.teams` op.
392+ // / * The reduction operand of the `omp.loop` op is updated to be the **new**
393+ // / reduction block argument of the `omp.teams` op.
394+ class ReductionsHoistingPattern
395+ : public mlir::OpConversionPattern<mlir::omp::TeamsOp> {
396+ public:
397+ using mlir::OpConversionPattern<mlir::omp::TeamsOp>::OpConversionPattern;
398+
399+ static mlir::omp::LoopOp
400+ tryToFindNestedLoopWithReduction (mlir::omp::TeamsOp teamsOp) {
401+ if (teamsOp.getRegion ().getBlocks ().size () != 1 )
402+ return nullptr ;
403+
404+ mlir::Block &teamsBlock = *teamsOp.getRegion ().begin ();
405+ auto loopOpIter = llvm::find_if (teamsBlock, [](mlir::Operation &op) {
406+ auto nestedLoopOp = llvm::dyn_cast<mlir::omp::LoopOp>(&op);
407+
408+ if (!nestedLoopOp)
409+ return false ;
410+
411+ return !nestedLoopOp.getReductionVars ().empty ();
412+ });
413+
414+ if (loopOpIter == teamsBlock.end ())
415+ return nullptr ;
416+
417+ // TODO return error if more than one loop op is nested. We need to
418+ // coalesce reductions in this case.
419+ return llvm::cast<mlir::omp::LoopOp>(loopOpIter);
420+ }
421+
422+ mlir::LogicalResult
423+ matchAndRewrite (mlir::omp::TeamsOp teamsOp, OpAdaptor adaptor,
424+ mlir::ConversionPatternRewriter &rewriter) const override {
425+ mlir::omp::LoopOp nestedLoopOp = tryToFindNestedLoopWithReduction (teamsOp);
426+
427+ rewriter.modifyOpInPlace (teamsOp, [&]() {
428+ teamsOp.setReductionMod (nestedLoopOp.getReductionMod ());
429+ teamsOp.getReductionVarsMutable ().assign (nestedLoopOp.getReductionVars ());
430+ teamsOp.setReductionByref (nestedLoopOp.getReductionByref ());
431+ teamsOp.setReductionSymsAttr (nestedLoopOp.getReductionSymsAttr ());
432+
433+ auto blockArgIface =
434+ llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*teamsOp);
435+ unsigned reductionArgsStart = blockArgIface.getPrivateBlockArgsStart () +
436+ blockArgIface.numPrivateBlockArgs ();
437+ llvm::SmallVector<mlir::Value> newLoopOpReductionOperands;
438+
439+ for (auto [idx, reductionVar] :
440+ llvm::enumerate (nestedLoopOp.getReductionVars ())) {
441+ mlir::BlockArgument newTeamsOpReductionBlockArg =
442+ teamsOp.getRegion ().insertArgument (reductionArgsStart + idx,
443+ reductionVar.getType (),
444+ reductionVar.getLoc ());
445+ newLoopOpReductionOperands.push_back (newTeamsOpReductionBlockArg);
446+ }
447+
448+ nestedLoopOp.getReductionVarsMutable ().assign (newLoopOpReductionOperands);
449+ });
450+
451+ return mlir::success ();
452+ }
453+ };
454+
346455class GenericLoopConversionPass
347456 : public flangomp::impl::GenericLoopConversionPassBase<
348457 GenericLoopConversionPass> {
@@ -357,11 +466,23 @@ class GenericLoopConversionPass
357466
358467 mlir::MLIRContext *context = &getContext ();
359468 mlir::RewritePatternSet patterns (context);
360- patterns.insert <GenericLoopConversionPattern>(context);
469+ patterns.insert <ReductionsHoistingPattern, GenericLoopConversionPattern>(
470+ context);
361471 mlir::ConversionTarget target (*context);
362472
363473 target.markUnknownOpDynamicallyLegal (
364474 [](mlir::Operation *) { return true ; });
475+
476+ target.addDynamicallyLegalOp <mlir::omp::TeamsOp>(
477+ [](mlir::omp::TeamsOp teamsOp) {
478+ // If teamsOp's reductions are already populated, then the op is
479+ // legal. Additionally, the op is legal if it does not nest a LoopOp
480+ // with reductions.
481+ return !teamsOp.getReductionVars ().empty () ||
482+ ReductionsHoistingPattern::tryToFindNestedLoopWithReduction (
483+ teamsOp) == nullptr ;
484+ });
485+
365486 target.addDynamicallyLegalOp <mlir::omp::LoopOp>(
366487 [](mlir::omp::LoopOp loopOp) {
367488 return mlir::failed (
0 commit comments