@@ -137,6 +137,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
137137
138138 liveIns.push_back (operand->get ());
139139 });
140+
141+ for (mlir::Value local : loop.getLocalVars ())
142+ liveIns.push_back (local);
140143}
141144
142145// / Collects values that are local to a loop: "loop-local values". A loop-local
@@ -251,8 +254,7 @@ class DoConcurrentConversion
251254 .getIsTargetDevice ();
252255
253256 mlir::omp::TargetOperands targetClauseOps;
254- genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, mapper,
255- loopNestClauseOps,
257+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, loopNestClauseOps,
256258 isTargetDevice ? nullptr : &targetClauseOps);
257259
258260 LiveInShapeInfoMap liveInShapeInfoMap;
@@ -274,14 +276,13 @@ class DoConcurrentConversion
274276 }
275277
276278 mlir::omp::ParallelOp parallelOp =
277- genParallelOp (doLoop. getLoc (), rewriter , ivInfos, mapper);
279+ genParallelOp (rewriter, loop , ivInfos, mapper);
278280
279281 // Only set as composite when part of `distribute parallel do`.
280282 parallelOp.setComposite (mapToDevice);
281283
282284 if (!mapToDevice)
283- genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, mapper,
284- loopNestClauseOps);
285+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, loopNestClauseOps);
285286
286287 for (mlir::Value local : locals)
287288 looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
@@ -290,10 +291,38 @@ class DoConcurrentConversion
290291 if (mapToDevice)
291292 genDistributeOp (doLoop.getLoc (), rewriter).setComposite (/* val=*/ true );
292293
293- mlir::omp::LoopNestOp ompLoopNest =
294+ auto [loopNestOp, wsLoopOp] =
294295 genWsLoopOp (rewriter, loop, mapper, loopNestClauseOps,
295296 /* isComposite=*/ mapToDevice);
296297
298+ // `local` region arguments are transferred/cloned from the `do concurrent`
299+ // loop to the loopnest op when the region is cloned above. Instead, these
300+ // region arguments should be on the workshare loop's region.
301+ if (mapToDevice) {
302+ for (auto [parallelArg, loopNestArg] : llvm::zip_equal (
303+ parallelOp.getRegion ().getArguments (),
304+ loopNestOp.getRegion ().getArguments ().slice (
305+ loop.getLocalOperandsStart (), loop.getNumLocalOperands ())))
306+ rewriter.replaceAllUsesWith (loopNestArg, parallelArg);
307+
308+ for (auto [wsloopArg, loopNestArg] : llvm::zip_equal (
309+ wsLoopOp.getRegion ().getArguments (),
310+ loopNestOp.getRegion ().getArguments ().slice (
311+ loop.getReduceOperandsStart (), loop.getNumReduceOperands ())))
312+ rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
313+ } else {
314+ for (auto [wsloopArg, loopNestArg] :
315+ llvm::zip_equal (wsLoopOp.getRegion ().getArguments (),
316+ loopNestOp.getRegion ().getArguments ().drop_front (
317+ loopNestClauseOps.loopLowerBounds .size ())))
318+ rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
319+ }
320+
321+ for (unsigned i = 0 ;
322+ i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
323+ loopNestOp.getRegion ().eraseArgument (
324+ loopNestClauseOps.loopLowerBounds .size ());
325+
297326 rewriter.setInsertionPoint (doLoop);
298327 fir::FirOpBuilder builder (
299328 rewriter,
@@ -314,7 +343,7 @@ class DoConcurrentConversion
314343 // Mark `unordered` loops that are not perfectly nested to be skipped from
315344 // the legality check of the `ConversionTarget` since we are not interested
316345 // in mapping them to OpenMP.
317- ompLoopNest ->walk ([&](fir::DoConcurrentOp doLoop) {
346+ loopNestOp ->walk ([&](fir::DoConcurrentOp doLoop) {
318347 concurrentLoopsToSkip.insert (doLoop);
319348 });
320349
@@ -370,11 +399,21 @@ class DoConcurrentConversion
370399 llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
371400
372401 mlir::omp::ParallelOp
373- genParallelOp (mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
402+ genParallelOp (mlir::ConversionPatternRewriter &rewriter,
403+ fir::DoConcurrentLoopOp loop,
374404 looputils::InductionVariableInfos &ivInfos,
375405 mlir::IRMapping &mapper) const {
376- auto parallelOp = mlir::omp::ParallelOp::create (rewriter, loc);
377- rewriter.createBlock (¶llelOp.getRegion ());
406+ mlir::omp::ParallelOperands parallelOps;
407+
408+ if (mapToDevice)
409+ genPrivatizers (rewriter, mapper, loop, parallelOps);
410+
411+ mlir::Location loc = loop.getLoc ();
412+ auto parallelOp = mlir::omp::ParallelOp::create (rewriter, loc, parallelOps);
413+ Fortran::common::openmp::EntryBlockArgs parallelArgs;
414+ parallelArgs.priv .vars = parallelOps.privateVars ;
415+ Fortran::common::openmp::genEntryBlock (rewriter, parallelArgs,
416+ parallelOp.getRegion ());
378417 rewriter.setInsertionPoint (mlir::omp::TerminatorOp::create (rewriter, loc));
379418
380419 genLoopNestIndVarAllocs (rewriter, ivInfos, mapper);
@@ -411,7 +450,7 @@ class DoConcurrentConversion
411450
412451 void genLoopNestClauseOps (
413452 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
414- fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
453+ fir::DoConcurrentLoopOp loop,
415454 mlir::omp::LoopNestOperands &loopNestClauseOps,
416455 mlir::omp::TargetOperands *targetClauseOps = nullptr ) const {
417456 assert (loopNestClauseOps.loopLowerBounds .empty () &&
@@ -440,59 +479,14 @@ class DoConcurrentConversion
440479 loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
441480 }
442481
443- mlir::omp::LoopNestOp
482+ std::pair< mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
444483 genWsLoopOp (mlir::ConversionPatternRewriter &rewriter,
445484 fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
446485 const mlir::omp::LoopNestOperands &clauseOps,
447486 bool isComposite) const {
448487 mlir::omp::WsloopOperands wsloopClauseOps;
449-
450- auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
451- mlir::Region &ompRegion) {
452- if (!firRegion.empty ()) {
453- rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
454- auto firYield =
455- mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
456- rewriter.setInsertionPoint (firYield);
457- mlir::omp::YieldOp::create (rewriter, firYield.getLoc (),
458- firYield.getOperands ());
459- rewriter.eraseOp (firYield);
460- }
461- };
462-
463- // For `local` (and `local_init`) opernads, emit corresponding `private`
464- // clauses and attach these clauses to the workshare loop.
465- if (!loop.getLocalVars ().empty ())
466- for (auto [op, sym, arg] : llvm::zip_equal (
467- loop.getLocalVars (),
468- loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
469- loop.getRegionLocalArgs ())) {
470- auto localizer = moduleSymbolTable.lookup <fir::LocalitySpecifierOp>(
471- sym.getLeafReference ());
472- if (localizer.getLocalitySpecifierType () ==
473- fir::LocalitySpecifierType::LocalInit)
474- TODO (localizer.getLoc (),
475- " local_init conversion is not supported yet" );
476-
477- mlir::OpBuilder::InsertionGuard guard (rewriter);
478- rewriter.setInsertionPointAfter (localizer);
479-
480- auto privatizer = mlir::omp::PrivateClauseOp::create (
481- rewriter, localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
482- localizer.getTypeAttr ().getValue (),
483- mlir::omp::DataSharingClauseType::Private);
484-
485- cloneFIRRegionToOMP (localizer.getInitRegion (),
486- privatizer.getInitRegion ());
487- cloneFIRRegionToOMP (localizer.getDeallocRegion (),
488- privatizer.getDeallocRegion ());
489-
490- moduleSymbolTable.insert (privatizer);
491-
492- wsloopClauseOps.privateVars .push_back (op);
493- wsloopClauseOps.privateSyms .push_back (
494- mlir::SymbolRefAttr::get (privatizer));
495- }
488+ if (!mapToDevice)
489+ genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
496490
497491 if (!loop.getReduceVars ().empty ()) {
498492 for (auto [op, byRef, sym, arg] : llvm::zip_equal (
@@ -515,15 +509,15 @@ class DoConcurrentConversion
515509 rewriter, firReducer.getLoc (), ompReducerName,
516510 firReducer.getTypeAttr ().getValue ());
517511
518- cloneFIRRegionToOMP (firReducer.getAllocRegion (),
512+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
519513 ompReducer.getAllocRegion ());
520- cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
514+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
521515 ompReducer.getInitializerRegion ());
522- cloneFIRRegionToOMP (firReducer.getReductionRegion (),
516+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
523517 ompReducer.getReductionRegion ());
524- cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
518+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
525519 ompReducer.getAtomicReductionRegion ());
526- cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
520+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
527521 ompReducer.getCleanupRegion ());
528522 moduleSymbolTable.insert (ompReducer);
529523 }
@@ -555,21 +549,10 @@ class DoConcurrentConversion
555549
556550 rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
557551 mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
552+ loop->getParentOfType <mlir::ModuleOp>().print (
553+ llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
558554
559- // `local` region arguments are transferred/cloned from the `do concurrent`
560- // loop to the loopnest op when the region is cloned above. Instead, these
561- // region arguments should be on the workshare loop's region.
562- for (auto [wsloopArg, loopNestArg] :
563- llvm::zip_equal (wsloopOp.getRegion ().getArguments (),
564- loopNestOp.getRegion ().getArguments ().drop_front (
565- clauseOps.loopLowerBounds .size ())))
566- rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
567-
568- for (unsigned i = 0 ;
569- i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
570- loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
571-
572- return loopNestOp;
555+ return {loopNestOp, wsloopOp};
573556 }
574557
575558 void genBoundsOps (fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -810,6 +793,59 @@ class DoConcurrentConversion
810793 return distOp;
811794 }
812795
796+ void cloneFIRRegionToOMP (mlir::ConversionPatternRewriter &rewriter,
797+ mlir::Region &firRegion,
798+ mlir::Region &ompRegion) const {
799+ if (!firRegion.empty ()) {
800+ rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
801+ auto firYield =
802+ mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
803+ rewriter.setInsertionPoint (firYield);
804+ mlir::omp::YieldOp::create (rewriter, firYield.getLoc (),
805+ firYield.getOperands ());
806+ rewriter.eraseOp (firYield);
807+ }
808+ }
809+
810+ void genPrivatizers (mlir::ConversionPatternRewriter &rewriter,
811+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
812+ mlir::omp::PrivateClauseOps &privateClauseOps) const {
813+ // For `local` (and `local_init`) operands, emit corresponding `private`
814+ // clauses and attach these clauses to the workshare loop.
815+ if (!loop.getLocalVars ().empty ())
816+ for (auto [var, sym, arg] : llvm::zip_equal (
817+ loop.getLocalVars (),
818+ loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
819+ loop.getRegionLocalArgs ())) {
820+ auto localizer = moduleSymbolTable.lookup <fir::LocalitySpecifierOp>(
821+ sym.getLeafReference ());
822+ if (localizer.getLocalitySpecifierType () ==
823+ fir::LocalitySpecifierType::LocalInit)
824+ TODO (localizer.getLoc (),
825+ " local_init conversion is not supported yet" );
826+
827+ mlir::OpBuilder::InsertionGuard guard (rewriter);
828+ rewriter.setInsertionPointAfter (localizer);
829+
830+ auto privatizer = mlir::omp::PrivateClauseOp::create (
831+ rewriter, localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
832+ localizer.getTypeAttr ().getValue (),
833+ mlir::omp::DataSharingClauseType::Private);
834+
835+ cloneFIRRegionToOMP (rewriter, localizer.getInitRegion (),
836+ privatizer.getInitRegion ());
837+ cloneFIRRegionToOMP (rewriter, localizer.getDeallocRegion (),
838+ privatizer.getDeallocRegion ());
839+
840+ moduleSymbolTable.insert (privatizer);
841+
842+ privateClauseOps.privateVars .push_back (mapToDevice ? mapper.lookup (var)
843+ : var);
844+ privateClauseOps.privateSyms .push_back (
845+ mlir::SymbolRefAttr::get (privatizer));
846+ }
847+ }
848+
813849 bool mapToDevice;
814850 llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
815851 mlir::SymbolTable &moduleSymbolTable;
0 commit comments