@@ -312,6 +312,19 @@ class DoConcurrentConversion
312312 bool isComposite) const {
313313 mlir::omp::WsloopOperands wsloopClauseOps;
314314
315+ auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
316+ mlir::Region &ompRegion) {
317+ if (!firRegion.empty ()) {
318+ rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
319+ auto firYield =
320+ mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
321+ rewriter.setInsertionPoint (firYield);
322+ rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
323+ firYield.getOperands ());
324+ rewriter.eraseOp (firYield);
325+ }
326+ };
327+
315328 // For `local` (and `local_init`) opernads, emit corresponding `private`
316329 // clauses and attach these clauses to the workshare loop.
317330 if (!loop.getLocalVars ().empty ())
@@ -326,50 +339,65 @@ class DoConcurrentConversion
326339 TODO (localizer.getLoc (),
327340 " local_init conversion is not supported yet" );
328341
329- auto oldIP = rewriter. saveInsertionPoint ( );
342+ mlir::OpBuilder::InsertionGuard guard (rewriter );
330343 rewriter.setInsertionPointAfter (localizer);
344+
331345 auto privatizer = rewriter.create <mlir::omp::PrivateClauseOp>(
332346 localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
333347 localizer.getTypeAttr ().getValue (),
334348 mlir::omp::DataSharingClauseType::Private);
335349
336- if (!localizer.getInitRegion ().empty ()) {
337- rewriter.cloneRegionBefore (localizer.getInitRegion (),
338- privatizer.getInitRegion (),
339- privatizer.getInitRegion ().begin ());
340- auto firYield = mlir::cast<fir::YieldOp>(
341- privatizer.getInitRegion ().back ().getTerminator ());
342- rewriter.setInsertionPoint (firYield);
343- rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
344- firYield.getOperands ());
345- rewriter.eraseOp (firYield);
346- }
347-
348- if (!localizer.getDeallocRegion ().empty ()) {
349- rewriter.cloneRegionBefore (localizer.getDeallocRegion (),
350- privatizer.getDeallocRegion (),
351- privatizer.getDeallocRegion ().begin ());
352- auto firYield = mlir::cast<fir::YieldOp>(
353- privatizer.getDeallocRegion ().back ().getTerminator ());
354- rewriter.setInsertionPoint (firYield);
355- rewriter.create <mlir::omp::YieldOp>(firYield.getLoc (),
356- firYield.getOperands ());
357- rewriter.eraseOp (firYield);
358- }
359-
360- rewriter.restoreInsertionPoint (oldIP);
350+ cloneFIRRegionToOMP (localizer.getInitRegion (),
351+ privatizer.getInitRegion ());
352+ cloneFIRRegionToOMP (localizer.getDeallocRegion (),
353+ privatizer.getDeallocRegion ());
361354
362355 wsloopClauseOps.privateVars .push_back (op);
363356 wsloopClauseOps.privateSyms .push_back (
364357 mlir::SymbolRefAttr::get (privatizer));
365358 }
366359
360+ if (!loop.getReduceVars ().empty ()) {
361+ for (auto [op, byRef, sym, arg] : llvm::zip_equal (
362+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
363+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
364+ loop.getRegionReduceArgs ())) {
365+ auto firReducer =
366+ mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
367+ loop, sym);
368+
369+ mlir::OpBuilder::InsertionGuard guard (rewriter);
370+ rewriter.setInsertionPointAfter (firReducer);
371+
372+ auto ompReducer = rewriter.create <mlir::omp::DeclareReductionOp>(
373+ firReducer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
374+ firReducer.getTypeAttr ().getValue ());
375+
376+ cloneFIRRegionToOMP (firReducer.getAllocRegion (),
377+ ompReducer.getAllocRegion ());
378+ cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
379+ ompReducer.getInitializerRegion ());
380+ cloneFIRRegionToOMP (firReducer.getReductionRegion (),
381+ ompReducer.getReductionRegion ());
382+ cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
383+ ompReducer.getAtomicReductionRegion ());
384+ cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
385+ ompReducer.getCleanupRegion ());
386+
387+ wsloopClauseOps.reductionVars .push_back (op);
388+ wsloopClauseOps.reductionByref .push_back (byRef);
389+ wsloopClauseOps.reductionSyms .push_back (
390+ mlir::SymbolRefAttr::get (ompReducer));
391+ }
392+ }
393+
367394 auto wsloopOp =
368395 rewriter.create <mlir::omp::WsloopOp>(loop.getLoc (), wsloopClauseOps);
369396 wsloopOp.setComposite (isComposite);
370397
371398 Fortran::common::openmp::EntryBlockArgs wsloopArgs;
372399 wsloopArgs.priv .vars = wsloopClauseOps.privateVars ;
400+ wsloopArgs.reduction .vars = wsloopClauseOps.reductionVars ;
373401 Fortran::common::openmp::genEntryBlock (rewriter, wsloopArgs,
374402 wsloopOp.getRegion ());
375403
@@ -393,7 +421,8 @@ class DoConcurrentConversion
393421 clauseOps.loopLowerBounds .size ())))
394422 rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
395423
396- for (unsigned i = 0 ; i < loop.getLocalVars ().size (); ++i)
424+ for (unsigned i = 0 ;
425+ i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
397426 loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
398427
399428 return loopNestOp;
0 commit comments