@@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
141141
142142 for (mlir::Value local : loop.getLocalVars ())
143143 liveIns.push_back (local);
144+
145+ for (mlir::Value reduce : loop.getReduceVars ())
146+ liveIns.push_back (reduce);
144147}
145148
146149// / Collects values that are local to a loop: "loop-local values". A loop-local
@@ -319,7 +322,7 @@ class DoConcurrentConversion
319322 targetOp =
320323 genTargetOp (doLoop.getLoc (), rewriter, mapper, loopNestLiveIns,
321324 targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
322- genTeamsOp (doLoop. getLoc (), rewriter );
325+ genTeamsOp (rewriter, loop, mapper );
323326 }
324327
325328 mlir::omp::ParallelOp parallelOp =
@@ -492,46 +495,7 @@ class DoConcurrentConversion
492495 if (!mapToDevice)
493496 genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
494497
495- if (!loop.getReduceVars ().empty ()) {
496- for (auto [op, byRef, sym, arg] : llvm::zip_equal (
497- loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
498- loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
499- loop.getRegionReduceArgs ())) {
500- auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
501- sym.getLeafReference ());
502-
503- mlir::OpBuilder::InsertionGuard guard (rewriter);
504- rewriter.setInsertionPointAfter (firReducer);
505- std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
506-
507- auto ompReducer =
508- moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
509- rewriter.getStringAttr (ompReducerName));
510-
511- if (!ompReducer) {
512- ompReducer = mlir::omp::DeclareReductionOp::create (
513- rewriter, firReducer.getLoc (), ompReducerName,
514- firReducer.getTypeAttr ().getValue ());
515-
516- cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
517- ompReducer.getAllocRegion ());
518- cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
519- ompReducer.getInitializerRegion ());
520- cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
521- ompReducer.getReductionRegion ());
522- cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
523- ompReducer.getAtomicReductionRegion ());
524- cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
525- ompReducer.getCleanupRegion ());
526- moduleSymbolTable.insert (ompReducer);
527- }
528-
529- wsloopClauseOps.reductionVars .push_back (op);
530- wsloopClauseOps.reductionByref .push_back (byRef);
531- wsloopClauseOps.reductionSyms .push_back (
532- mlir::SymbolRefAttr::get (ompReducer));
533- }
534- }
498+ genReductions (rewriter, mapper, loop, wsloopClauseOps);
535499
536500 auto wsloopOp =
537501 mlir::omp::WsloopOp::create (rewriter, loop.getLoc (), wsloopClauseOps);
@@ -553,8 +517,6 @@ class DoConcurrentConversion
553517
554518 rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
555519 mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
556- loop->getParentOfType <mlir::ModuleOp>().print (
557- llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
558520
559521 return {loopNestOp, wsloopOp};
560522 }
@@ -778,15 +740,26 @@ class DoConcurrentConversion
778740 liveInName, shape);
779741 }
780742
781- mlir::omp::TeamsOp
782- genTeamsOp (mlir::Location loc,
783- mlir::ConversionPatternRewriter &rewriter) const {
784- auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(
785- loc, /* clauses=*/ mlir::omp::TeamsOperands{});
743+ mlir::omp::TeamsOp genTeamsOp (mlir::ConversionPatternRewriter &rewriter,
744+ fir::DoConcurrentLoopOp loop,
745+ mlir::IRMapping &mapper) const {
746+ mlir::omp::TeamsOperands teamsOps;
747+ genReductions (rewriter, mapper, loop, teamsOps);
748+
749+ mlir::Location loc = loop.getLoc ();
750+ auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(loc, teamsOps);
751+ Fortran::common::openmp::EntryBlockArgs teamsArgs;
752+ teamsArgs.reduction .vars = teamsOps.reductionVars ;
753+ Fortran::common::openmp::genEntryBlock (rewriter, teamsArgs,
754+ teamsOp.getRegion ());
786755
787- rewriter.createBlock (&teamsOp.getRegion ());
788756 rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
789757
758+ for (auto [loopVar, teamsArg] : llvm::zip_equal (
759+ loop.getReduceVars (), teamsOp.getRegion ().getArguments ())) {
760+ mapper.map (loopVar, teamsArg);
761+ }
762+
790763 return teamsOp;
791764 }
792765
@@ -861,6 +834,52 @@ class DoConcurrentConversion
861834 }
862835 }
863836
837+ void genReductions (mlir::ConversionPatternRewriter &rewriter,
838+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
839+ mlir::omp::ReductionClauseOps &reductionClauseOps) const {
840+ if (!loop.getReduceVars ().empty ()) {
841+ for (auto [var, byRef, sym, arg] : llvm::zip_equal (
842+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
843+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
844+ loop.getRegionReduceArgs ())) {
845+ auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
846+ sym.getLeafReference ());
847+
848+ mlir::OpBuilder::InsertionGuard guard (rewriter);
849+ rewriter.setInsertionPointAfter (firReducer);
850+ std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
851+
852+ auto ompReducer =
853+ moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
854+ rewriter.getStringAttr (ompReducerName));
855+
856+ if (!ompReducer) {
857+ ompReducer = mlir::omp::DeclareReductionOp::create (
858+ rewriter, firReducer.getLoc (), ompReducerName,
859+ firReducer.getTypeAttr ().getValue ());
860+
861+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
862+ ompReducer.getAllocRegion ());
863+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
864+ ompReducer.getInitializerRegion ());
865+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
866+ ompReducer.getReductionRegion ());
867+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
868+ ompReducer.getAtomicReductionRegion ());
869+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
870+ ompReducer.getCleanupRegion ());
871+ moduleSymbolTable.insert (ompReducer);
872+ }
873+
874+ reductionClauseOps.reductionVars .push_back (
875+ mapToDevice ? mapper.lookup (var) : var);
876+ reductionClauseOps.reductionByref .push_back (byRef);
877+ reductionClauseOps.reductionSyms .push_back (
878+ mlir::SymbolRefAttr::get (ompReducer));
879+ }
880+ }
881+ }
882+
864883 bool mapToDevice;
865884 llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
866885 mlir::SymbolTable &moduleSymbolTable;
0 commit comments