@@ -453,6 +453,19 @@ class DoConcurrentConversion
453453 bool isComposite) const {
454454 mlir::omp::WsloopOperands wsloopClauseOps;
455455
456+ auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
457+ mlir::Region &ompRegion) {
458+ if (!firRegion.empty ()) {
459+ rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
460+ auto firYield =
461+ mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
462+ rewriter.setInsertionPoint (firYield);
463+ rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
464+ firYield.getOperands ());
465+ rewriter.eraseOp (firYield);
466+ }
467+ };
468+
456469 // For `local` (and `local_init`) opernads, emit corresponding `private`
457470 // clauses and attach these clauses to the workshare loop.
458471 if (!loop.getLocalVars ().empty ())
@@ -467,50 +480,65 @@ class DoConcurrentConversion
467480 TODO (localizer.getLoc (),
468481 " local_init conversion is not supported yet" );
469482
470- auto oldIP = rewriter. saveInsertionPoint ( );
483+ mlir::OpBuilder::InsertionGuard guard (rewriter );
471484 rewriter.setInsertionPointAfter (localizer);
485+
472486 auto privatizer = rewriter.create <mlir::omp::PrivateClauseOp>(
473487 localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
474488 localizer.getTypeAttr ().getValue (),
475489 mlir::omp::DataSharingClauseType::Private);
476490
477- if (!localizer.getInitRegion ().empty ()) {
478- rewriter.cloneRegionBefore (localizer.getInitRegion (),
479- privatizer.getInitRegion (),
480- privatizer.getInitRegion ().begin ());
481- auto firYield = mlir::cast<fir::YieldOp>(
482- privatizer.getInitRegion ().back ().getTerminator ());
483- rewriter.setInsertionPoint (firYield);
484- rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
485- firYield.getOperands ());
486- rewriter.eraseOp (firYield);
487- }
488-
489- if (!localizer.getDeallocRegion ().empty ()) {
490- rewriter.cloneRegionBefore (localizer.getDeallocRegion (),
491- privatizer.getDeallocRegion (),
492- privatizer.getDeallocRegion ().begin ());
493- auto firYield = mlir::cast<fir::YieldOp>(
494- privatizer.getDeallocRegion ().back ().getTerminator ());
495- rewriter.setInsertionPoint (firYield);
496- rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
497- firYield.getOperands ());
498- rewriter.eraseOp (firYield);
499- }
500-
501- rewriter.restoreInsertionPoint (oldIP);
491+ cloneFIRRegionToOMP (localizer.getInitRegion (),
492+ privatizer.getInitRegion ());
493+ cloneFIRRegionToOMP (localizer.getDeallocRegion (),
494+ privatizer.getDeallocRegion ());
502495
503496 wsloopClauseOps.privateVars .push_back (op);
504497 wsloopClauseOps.privateSyms .push_back (
505498 mlir::SymbolRefAttr::get (privatizer));
506499 }
507500
501+ if (!loop.getReduceVars ().empty ()) {
502+ for (auto [op, byRef, sym, arg] : llvm::zip_equal (
503+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
504+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
505+ loop.getRegionReduceArgs ())) {
506+ auto firReducer =
507+ mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
508+ loop, sym);
509+
510+ mlir::OpBuilder::InsertionGuard guard (rewriter);
511+ rewriter.setInsertionPointAfter (firReducer);
512+
513+ auto ompReducer = rewriter.create <mlir::omp::DeclareReductionOp>(
514+ firReducer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
515+ firReducer.getTypeAttr ().getValue ());
516+
517+ cloneFIRRegionToOMP (firReducer.getAllocRegion (),
518+ ompReducer.getAllocRegion ());
519+ cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
520+ ompReducer.getInitializerRegion ());
521+ cloneFIRRegionToOMP (firReducer.getReductionRegion (),
522+ ompReducer.getReductionRegion ());
523+ cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
524+ ompReducer.getAtomicReductionRegion ());
525+ cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
526+ ompReducer.getCleanupRegion ());
527+
528+ wsloopClauseOps.reductionVars .push_back (op);
529+ wsloopClauseOps.reductionByref .push_back (byRef);
530+ wsloopClauseOps.reductionSyms .push_back (
531+ mlir::SymbolRefAttr::get (ompReducer));
532+ }
533+ }
534+
508535 auto wsloopOp =
509536 rewriter.create <mlir::omp::WsloopOp>(loop.getLoc (), wsloopClauseOps);
510537 wsloopOp.setComposite (isComposite);
511538
512539 Fortran::common::openmp::EntryBlockArgs wsloopArgs;
513540 wsloopArgs.priv .vars = wsloopClauseOps.privateVars ;
541+ wsloopArgs.reduction .vars = wsloopClauseOps.reductionVars ;
514542 Fortran::common::openmp::genEntryBlock (rewriter, wsloopArgs,
515543 wsloopOp.getRegion ());
516544
@@ -534,7 +562,8 @@ class DoConcurrentConversion
534562 clauseOps.loopLowerBounds .size ())))
535563 rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
536564
537- for (unsigned i = 0 ; i < loop.getLocalVars ().size (); ++i)
565+ for (unsigned i = 0 ;
566+ i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
538567 loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
539568
540569 return loopNestOp;
0 commit comments