@@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
141
141
142
142
for (mlir::Value local : loop.getLocalVars ())
143
143
liveIns.push_back (local);
144
+
145
+ for (mlir::Value reduce : loop.getReduceVars ())
146
+ liveIns.push_back (reduce);
144
147
}
145
148
146
149
// / Collects values that are local to a loop: "loop-local values". A loop-local
@@ -319,7 +322,7 @@ class DoConcurrentConversion
319
322
targetOp =
320
323
genTargetOp (doLoop.getLoc (), rewriter, mapper, loopNestLiveIns,
321
324
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
322
- genTeamsOp (doLoop. getLoc (), rewriter );
325
+ genTeamsOp (rewriter, loop, mapper );
323
326
}
324
327
325
328
mlir::omp::ParallelOp parallelOp =
@@ -492,46 +495,7 @@ class DoConcurrentConversion
492
495
if (!mapToDevice)
493
496
genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
494
497
495
- if (!loop.getReduceVars ().empty ()) {
496
- for (auto [op, byRef, sym, arg] : llvm::zip_equal (
497
- loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
498
- loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
499
- loop.getRegionReduceArgs ())) {
500
- auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
501
- sym.getLeafReference ());
502
-
503
- mlir::OpBuilder::InsertionGuard guard (rewriter);
504
- rewriter.setInsertionPointAfter (firReducer);
505
- std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
506
-
507
- auto ompReducer =
508
- moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
509
- rewriter.getStringAttr (ompReducerName));
510
-
511
- if (!ompReducer) {
512
- ompReducer = mlir::omp::DeclareReductionOp::create (
513
- rewriter, firReducer.getLoc (), ompReducerName,
514
- firReducer.getTypeAttr ().getValue ());
515
-
516
- cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
517
- ompReducer.getAllocRegion ());
518
- cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
519
- ompReducer.getInitializerRegion ());
520
- cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
521
- ompReducer.getReductionRegion ());
522
- cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
523
- ompReducer.getAtomicReductionRegion ());
524
- cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
525
- ompReducer.getCleanupRegion ());
526
- moduleSymbolTable.insert (ompReducer);
527
- }
528
-
529
- wsloopClauseOps.reductionVars .push_back (op);
530
- wsloopClauseOps.reductionByref .push_back (byRef);
531
- wsloopClauseOps.reductionSyms .push_back (
532
- mlir::SymbolRefAttr::get (ompReducer));
533
- }
534
- }
498
+ genReductions (rewriter, mapper, loop, wsloopClauseOps);
535
499
536
500
auto wsloopOp =
537
501
mlir::omp::WsloopOp::create (rewriter, loop.getLoc (), wsloopClauseOps);
@@ -553,8 +517,6 @@ class DoConcurrentConversion
553
517
554
518
rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
555
519
mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
556
- loop->getParentOfType <mlir::ModuleOp>().print (
557
- llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
558
520
559
521
return {loopNestOp, wsloopOp};
560
522
}
@@ -778,15 +740,26 @@ class DoConcurrentConversion
778
740
liveInName, shape);
779
741
}
780
742
781
- mlir::omp::TeamsOp
782
- genTeamsOp (mlir::Location loc,
783
- mlir::ConversionPatternRewriter &rewriter) const {
784
- auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(
785
- loc, /* clauses=*/ mlir::omp::TeamsOperands{});
743
+ mlir::omp::TeamsOp genTeamsOp (mlir::ConversionPatternRewriter &rewriter,
744
+ fir::DoConcurrentLoopOp loop,
745
+ mlir::IRMapping &mapper) const {
746
+ mlir::omp::TeamsOperands teamsOps;
747
+ genReductions (rewriter, mapper, loop, teamsOps);
748
+
749
+ mlir::Location loc = loop.getLoc ();
750
+ auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(loc, teamsOps);
751
+ Fortran::common::openmp::EntryBlockArgs teamsArgs;
752
+ teamsArgs.reduction .vars = teamsOps.reductionVars ;
753
+ Fortran::common::openmp::genEntryBlock (rewriter, teamsArgs,
754
+ teamsOp.getRegion ());
786
755
787
- rewriter.createBlock (&teamsOp.getRegion ());
788
756
rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
789
757
758
+ for (auto [loopVar, teamsArg] : llvm::zip_equal (
759
+ loop.getReduceVars (), teamsOp.getRegion ().getArguments ())) {
760
+ mapper.map (loopVar, teamsArg);
761
+ }
762
+
790
763
return teamsOp;
791
764
}
792
765
@@ -861,6 +834,52 @@ class DoConcurrentConversion
861
834
}
862
835
}
863
836
837
+ void genReductions (mlir::ConversionPatternRewriter &rewriter,
838
+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
839
+ mlir::omp::ReductionClauseOps &reductionClauseOps) const {
840
+ if (!loop.getReduceVars ().empty ()) {
841
+ for (auto [var, byRef, sym, arg] : llvm::zip_equal (
842
+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
843
+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
844
+ loop.getRegionReduceArgs ())) {
845
+ auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
846
+ sym.getLeafReference ());
847
+
848
+ mlir::OpBuilder::InsertionGuard guard (rewriter);
849
+ rewriter.setInsertionPointAfter (firReducer);
850
+ std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
851
+
852
+ auto ompReducer =
853
+ moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
854
+ rewriter.getStringAttr (ompReducerName));
855
+
856
+ if (!ompReducer) {
857
+ ompReducer = mlir::omp::DeclareReductionOp::create (
858
+ rewriter, firReducer.getLoc (), ompReducerName,
859
+ firReducer.getTypeAttr ().getValue ());
860
+
861
+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
862
+ ompReducer.getAllocRegion ());
863
+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
864
+ ompReducer.getInitializerRegion ());
865
+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
866
+ ompReducer.getReductionRegion ());
867
+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
868
+ ompReducer.getAtomicReductionRegion ());
869
+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
870
+ ompReducer.getCleanupRegion ());
871
+ moduleSymbolTable.insert (ompReducer);
872
+ }
873
+
874
+ reductionClauseOps.reductionVars .push_back (
875
+ mapToDevice ? mapper.lookup (var) : var);
876
+ reductionClauseOps.reductionByref .push_back (byRef);
877
+ reductionClauseOps.reductionSyms .push_back (
878
+ mlir::SymbolRefAttr::get (ompReducer));
879
+ }
880
+ }
881
+ }
882
+
864
883
bool mapToDevice;
865
884
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
866
885
mlir::SymbolTable &moduleSymbolTable;
0 commit comments