@@ -312,6 +312,19 @@ class DoConcurrentConversion
312312              bool  isComposite) const  {
313313    mlir::omp::WsloopOperands wsloopClauseOps;
314314
315+     auto  cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
316+                                            mlir::Region &ompRegion) {
317+       if  (!firRegion.empty ()) {
318+         rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
319+         auto  firYield =
320+             mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
321+         rewriter.setInsertionPoint (firYield);
322+         rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
323+                                             firYield.getOperands ());
324+         rewriter.eraseOp (firYield);
325+       }
326+     };
327+ 
315328    //  For `local` (and `local_init`) operands, emit corresponding `private`
316329    //  clauses and attach these clauses to the workshare loop.
317330    if  (!loop.getLocalVars ().empty ())
@@ -326,50 +339,65 @@ class DoConcurrentConversion
326339          TODO (localizer.getLoc (),
327340               " local_init conversion is not supported yet"  );
328341
329-         auto  oldIP = rewriter. saveInsertionPoint ( );
342+         mlir::OpBuilder::InsertionGuard  guard (rewriter );
330343        rewriter.setInsertionPointAfter (localizer);
344+ 
331345        auto  privatizer = rewriter.create <mlir::omp::PrivateClauseOp>(
332346            localizer.getLoc (), sym.getLeafReference ().str () + " .omp"  ,
333347            localizer.getTypeAttr ().getValue (),
334348            mlir::omp::DataSharingClauseType::Private);
335349
336-         if  (!localizer.getInitRegion ().empty ()) {
337-           rewriter.cloneRegionBefore (localizer.getInitRegion (),
338-                                      privatizer.getInitRegion (),
339-                                      privatizer.getInitRegion ().begin ());
340-           auto  firYield = mlir::cast<fir::YieldOp>(
341-               privatizer.getInitRegion ().back ().getTerminator ());
342-           rewriter.setInsertionPoint (firYield);
343-           rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
344-                                               firYield.getOperands ());
345-           rewriter.eraseOp (firYield);
346-         }
347- 
348-         if  (!localizer.getDeallocRegion ().empty ()) {
349-           rewriter.cloneRegionBefore (localizer.getDeallocRegion (),
350-                                      privatizer.getDeallocRegion (),
351-                                      privatizer.getDeallocRegion ().begin ());
352-           auto  firYield = mlir::cast<fir::YieldOp>(
353-               privatizer.getDeallocRegion ().back ().getTerminator ());
354-           rewriter.setInsertionPoint (firYield);
355-           rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
356-                                               firYield.getOperands ());
357-           rewriter.eraseOp (firYield);
358-         }
359- 
360-         rewriter.restoreInsertionPoint (oldIP);
350+         cloneFIRRegionToOMP (localizer.getInitRegion (),
351+                             privatizer.getInitRegion ());
352+         cloneFIRRegionToOMP (localizer.getDeallocRegion (),
353+                             privatizer.getDeallocRegion ());
361354
362355        wsloopClauseOps.privateVars .push_back (op);
363356        wsloopClauseOps.privateSyms .push_back (
364357            mlir::SymbolRefAttr::get (privatizer));
365358      }
366359
360+     if  (!loop.getReduceVars ().empty ()) {
361+       for  (auto  [op, byRef, sym, arg] : llvm::zip_equal (
362+                loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
363+                loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
364+                loop.getRegionReduceArgs ())) {
365+         auto  firReducer =
366+             mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
367+                 loop, sym);
368+ 
369+         mlir::OpBuilder::InsertionGuard guard (rewriter);
370+         rewriter.setInsertionPointAfter (firReducer);
371+ 
372+         auto  ompReducer = rewriter.create <mlir::omp::DeclareReductionOp>(
373+             firReducer.getLoc (), sym.getLeafReference ().str () + " .omp"  ,
374+             firReducer.getTypeAttr ().getValue ());
375+ 
376+         cloneFIRRegionToOMP (firReducer.getAllocRegion (),
377+                             ompReducer.getAllocRegion ());
378+         cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
379+                             ompReducer.getInitializerRegion ());
380+         cloneFIRRegionToOMP (firReducer.getReductionRegion (),
381+                             ompReducer.getReductionRegion ());
382+         cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
383+                             ompReducer.getAtomicReductionRegion ());
384+         cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
385+                             ompReducer.getCleanupRegion ());
386+ 
387+         wsloopClauseOps.reductionVars .push_back (op);
388+         wsloopClauseOps.reductionByref .push_back (byRef);
389+         wsloopClauseOps.reductionSyms .push_back (
390+             mlir::SymbolRefAttr::get (ompReducer));
391+       }
392+     }
393+ 
367394    auto  wsloopOp =
368395        rewriter.create <mlir::omp::WsloopOp>(loop.getLoc (), wsloopClauseOps);
369396    wsloopOp.setComposite (isComposite);
370397
371398    Fortran::common::openmp::EntryBlockArgs wsloopArgs;
372399    wsloopArgs.priv .vars  = wsloopClauseOps.privateVars ;
400+     wsloopArgs.reduction .vars  = wsloopClauseOps.reductionVars ;
373401    Fortran::common::openmp::genEntryBlock (rewriter, wsloopArgs,
374402                                           wsloopOp.getRegion ());
375403
@@ -393,7 +421,8 @@ class DoConcurrentConversion
393421                             clauseOps.loopLowerBounds .size ())))
394422      rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
395423
396-     for  (unsigned  i = 0 ; i < loop.getLocalVars ().size (); ++i)
424+     for  (unsigned  i = 0 ;
425+          i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
397426      loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
398427
399428    return  loopNestOp;
0 commit comments