@@ -138,6 +138,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
138138
139139 liveIns.push_back (operand->get ());
140140 });
141+
142+ for (mlir::Value local : loop.getLocalVars ())
143+ liveIns.push_back (local);
141144}
142145
143146// / Collects values that are local to a loop: "loop-local values". A loop-local
@@ -252,8 +255,7 @@ class DoConcurrentConversion
252255 .getIsTargetDevice ();
253256
254257 mlir::omp::TargetOperands targetClauseOps;
255- genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, mapper,
256- loopNestClauseOps,
258+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, loopNestClauseOps,
257259 isTargetDevice ? nullptr : &targetClauseOps);
258260
259261 LiveInShapeInfoMap liveInShapeInfoMap;
@@ -275,14 +277,13 @@ class DoConcurrentConversion
275277 }
276278
277279 mlir::omp::ParallelOp parallelOp =
278- genParallelOp (doLoop. getLoc (), rewriter , ivInfos, mapper);
280+ genParallelOp (rewriter, loop , ivInfos, mapper);
279281
280282 // Only set as composite when part of `distribute parallel do`.
281283 parallelOp.setComposite (mapToDevice);
282284
283285 if (!mapToDevice)
284- genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, mapper,
285- loopNestClauseOps);
286+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, loopNestClauseOps);
286287
287288 for (mlir::Value local : locals)
288289 looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
@@ -291,10 +292,38 @@ class DoConcurrentConversion
291292 if (mapToDevice)
292293 genDistributeOp (doLoop.getLoc (), rewriter).setComposite (/* val=*/ true );
293294
294- mlir::omp::LoopNestOp ompLoopNest =
295+ auto [loopNestOp, wsLoopOp] =
295296 genWsLoopOp (rewriter, loop, mapper, loopNestClauseOps,
296297 /* isComposite=*/ mapToDevice);
297298
299+ // `local` and `reduce` region arguments are transferred/cloned from the `do
300+ // concurrent` loop to the loopnest op when the region is cloned above. Instead, they
301+ // should be on the workshare loop's region (or, for `local` on device, the parallel op's).
302+ if (mapToDevice) {
303+ for (auto [parallelArg, loopNestArg] : llvm::zip_equal (
304+ parallelOp.getRegion ().getArguments (),
305+ loopNestOp.getRegion ().getArguments ().slice (
306+ loop.getLocalOperandsStart (), loop.getNumLocalOperands ())))
307+ rewriter.replaceAllUsesWith (loopNestArg, parallelArg);
308+
309+ for (auto [wsloopArg, loopNestArg] : llvm::zip_equal (
310+ wsLoopOp.getRegion ().getArguments (),
311+ loopNestOp.getRegion ().getArguments ().slice (
312+ loop.getReduceOperandsStart (), loop.getNumReduceOperands ())))
313+ rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
314+ } else {
315+ for (auto [wsloopArg, loopNestArg] :
316+ llvm::zip_equal (wsLoopOp.getRegion ().getArguments (),
317+ loopNestOp.getRegion ().getArguments ().drop_front (
318+ loopNestClauseOps.loopLowerBounds .size ())))
319+ rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
320+ }
321+
322+ for (unsigned i = 0 ;
323+ i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
324+ loopNestOp.getRegion ().eraseArgument (
325+ loopNestClauseOps.loopLowerBounds .size ());
326+
298327 rewriter.setInsertionPoint (doLoop);
299328 fir::FirOpBuilder builder (
300329 rewriter,
@@ -315,7 +344,7 @@ class DoConcurrentConversion
315344 // Mark `unordered` loops that are not perfectly nested to be skipped from
316345 // the legality check of the `ConversionTarget` since we are not interested
317346 // in mapping them to OpenMP.
318- ompLoopNest ->walk ([&](fir::DoConcurrentOp doLoop) {
347+ loopNestOp ->walk ([&](fir::DoConcurrentOp doLoop) {
319348 concurrentLoopsToSkip.insert (doLoop);
320349 });
321350
@@ -371,11 +400,21 @@ class DoConcurrentConversion
371400 llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
372401
373402 mlir::omp::ParallelOp
374- genParallelOp (mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
403+ genParallelOp (mlir::ConversionPatternRewriter &rewriter,
404+ fir::DoConcurrentLoopOp loop,
375405 looputils::InductionVariableInfos &ivInfos,
376406 mlir::IRMapping &mapper) const {
377- auto parallelOp = mlir::omp::ParallelOp::create (rewriter, loc);
378- rewriter.createBlock (¶llelOp.getRegion ());
407+ mlir::omp::ParallelOperands parallelOps;
408+
409+ if (mapToDevice)
410+ genPrivatizers (rewriter, mapper, loop, parallelOps);
411+
412+ mlir::Location loc = loop.getLoc ();
413+ auto parallelOp = mlir::omp::ParallelOp::create (rewriter, loc, parallelOps);
414+ Fortran::common::openmp::EntryBlockArgs parallelArgs;
415+ parallelArgs.priv .vars = parallelOps.privateVars ;
416+ Fortran::common::openmp::genEntryBlock (rewriter, parallelArgs,
417+ parallelOp.getRegion ());
379418 rewriter.setInsertionPoint (mlir::omp::TerminatorOp::create (rewriter, loc));
380419
381420 genLoopNestIndVarAllocs (rewriter, ivInfos, mapper);
@@ -412,7 +451,7 @@ class DoConcurrentConversion
412451
413452 void genLoopNestClauseOps (
414453 mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
415- fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
454+ fir::DoConcurrentLoopOp loop,
416455 mlir::omp::LoopNestOperands &loopNestClauseOps,
417456 mlir::omp::TargetOperands *targetClauseOps = nullptr ) const {
418457 assert (loopNestClauseOps.loopLowerBounds .empty () &&
@@ -443,59 +482,14 @@ class DoConcurrentConversion
443482 loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
444483 }
445484
446- mlir::omp::LoopNestOp
485+ std::pair< mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
447486 genWsLoopOp (mlir::ConversionPatternRewriter &rewriter,
448487 fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
449488 const mlir::omp::LoopNestOperands &clauseOps,
450489 bool isComposite) const {
451490 mlir::omp::WsloopOperands wsloopClauseOps;
452-
453- auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
454- mlir::Region &ompRegion) {
455- if (!firRegion.empty ()) {
456- rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
457- auto firYield =
458- mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
459- rewriter.setInsertionPoint (firYield);
460- mlir::omp::YieldOp::create (rewriter, firYield.getLoc (),
461- firYield.getOperands ());
462- rewriter.eraseOp (firYield);
463- }
464- };
465-
466- // For `local` (and `local_init`) opernads, emit corresponding `private`
467- // clauses and attach these clauses to the workshare loop.
468- if (!loop.getLocalVars ().empty ())
469- for (auto [op, sym, arg] : llvm::zip_equal (
470- loop.getLocalVars (),
471- loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
472- loop.getRegionLocalArgs ())) {
473- auto localizer = moduleSymbolTable.lookup <fir::LocalitySpecifierOp>(
474- sym.getLeafReference ());
475- if (localizer.getLocalitySpecifierType () ==
476- fir::LocalitySpecifierType::LocalInit)
477- TODO (localizer.getLoc (),
478- " local_init conversion is not supported yet" );
479-
480- mlir::OpBuilder::InsertionGuard guard (rewriter);
481- rewriter.setInsertionPointAfter (localizer);
482-
483- auto privatizer = mlir::omp::PrivateClauseOp::create (
484- rewriter, localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
485- localizer.getTypeAttr ().getValue (),
486- mlir::omp::DataSharingClauseType::Private);
487-
488- cloneFIRRegionToOMP (localizer.getInitRegion (),
489- privatizer.getInitRegion ());
490- cloneFIRRegionToOMP (localizer.getDeallocRegion (),
491- privatizer.getDeallocRegion ());
492-
493- moduleSymbolTable.insert (privatizer);
494-
495- wsloopClauseOps.privateVars .push_back (op);
496- wsloopClauseOps.privateSyms .push_back (
497- mlir::SymbolRefAttr::get (privatizer));
498- }
491+ if (!mapToDevice)
492+ genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
499493
500494 if (!loop.getReduceVars ().empty ()) {
501495 for (auto [op, byRef, sym, arg] : llvm::zip_equal (
@@ -518,15 +512,15 @@ class DoConcurrentConversion
518512 rewriter, firReducer.getLoc (), ompReducerName,
519513 firReducer.getTypeAttr ().getValue ());
520514
521- cloneFIRRegionToOMP (firReducer.getAllocRegion (),
515+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
522516 ompReducer.getAllocRegion ());
523- cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
517+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
524518 ompReducer.getInitializerRegion ());
525- cloneFIRRegionToOMP (firReducer.getReductionRegion (),
519+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
526520 ompReducer.getReductionRegion ());
527- cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
521+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
528522 ompReducer.getAtomicReductionRegion ());
529- cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
523+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
530524 ompReducer.getCleanupRegion ());
531525 moduleSymbolTable.insert (ompReducer);
532526 }
@@ -558,21 +552,10 @@ class DoConcurrentConversion
558552
559553 rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
560554 mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
555+ loop->getParentOfType <mlir::ModuleOp>().print (
556+ llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
561557
562- // `local` region arguments are transferred/cloned from the `do concurrent`
563- // loop to the loopnest op when the region is cloned above. Instead, these
564- // region arguments should be on the workshare loop's region.
565- for (auto [wsloopArg, loopNestArg] :
566- llvm::zip_equal (wsloopOp.getRegion ().getArguments (),
567- loopNestOp.getRegion ().getArguments ().drop_front (
568- clauseOps.loopLowerBounds .size ())))
569- rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
570-
571- for (unsigned i = 0 ;
572- i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
573- loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
574-
575- return loopNestOp;
558+ return {loopNestOp, wsloopOp};
576559 }
577560
578561 void genBoundsOps (fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -813,6 +796,59 @@ class DoConcurrentConversion
813796 return distOp;
814797 }
815798
799+ void cloneFIRRegionToOMP (mlir::ConversionPatternRewriter &rewriter,
800+ mlir::Region &firRegion,
801+ mlir::Region &ompRegion) const {
802+ if (!firRegion.empty ()) {
803+ rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
804+ auto firYield =
805+ mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
806+ rewriter.setInsertionPoint (firYield);
807+ mlir::omp::YieldOp::create (rewriter, firYield.getLoc (),
808+ firYield.getOperands ());
809+ rewriter.eraseOp (firYield);
810+ }
811+ }
812+
813+ void genPrivatizers (mlir::ConversionPatternRewriter &rewriter,
814+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
815+ mlir::omp::PrivateClauseOps &privateClauseOps) const {
816+ // For `local` (and `local_init`) operands, emit corresponding `private`
817+ // clauses and attach these clauses to the workshare loop.
818+ if (!loop.getLocalVars ().empty ())
819+ for (auto [var, sym, arg] : llvm::zip_equal (
820+ loop.getLocalVars (),
821+ loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
822+ loop.getRegionLocalArgs ())) {
823+ auto localizer = moduleSymbolTable.lookup <fir::LocalitySpecifierOp>(
824+ sym.getLeafReference ());
825+ if (localizer.getLocalitySpecifierType () ==
826+ fir::LocalitySpecifierType::LocalInit)
827+ TODO (localizer.getLoc (),
828+ " local_init conversion is not supported yet" );
829+
830+ mlir::OpBuilder::InsertionGuard guard (rewriter);
831+ rewriter.setInsertionPointAfter (localizer);
832+
833+ auto privatizer = mlir::omp::PrivateClauseOp::create (
834+ rewriter, localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
835+ localizer.getTypeAttr ().getValue (),
836+ mlir::omp::DataSharingClauseType::Private);
837+
838+ cloneFIRRegionToOMP (rewriter, localizer.getInitRegion (),
839+ privatizer.getInitRegion ());
840+ cloneFIRRegionToOMP (rewriter, localizer.getDeallocRegion (),
841+ privatizer.getDeallocRegion ());
842+
843+ moduleSymbolTable.insert (privatizer);
844+
845+ privateClauseOps.privateVars .push_back (mapToDevice ? mapper.lookup (var)
846+ : var);
847+ privateClauseOps.privateSyms .push_back (
848+ mlir::SymbolRefAttr::get (privatizer));
849+ }
850+ }
851+
816852 bool mapToDevice;
817853 llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
818854 mlir::SymbolTable &moduleSymbolTable;
0 commit comments