@@ -137,6 +137,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
137
137
138
138
liveIns.push_back (operand->get ());
139
139
});
140
+
141
+ for (mlir::Value local : loop.getLocalVars ())
142
+ liveIns.push_back (local);
140
143
}
141
144
142
145
// / Collects values that are local to a loop: "loop-local values". A loop-local
@@ -251,8 +254,7 @@ class DoConcurrentConversion
251
254
.getIsTargetDevice ();
252
255
253
256
mlir::omp::TargetOperands targetClauseOps;
254
- genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, mapper,
255
- loopNestClauseOps,
257
+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, loopNestClauseOps,
256
258
isTargetDevice ? nullptr : &targetClauseOps);
257
259
258
260
LiveInShapeInfoMap liveInShapeInfoMap;
@@ -274,14 +276,13 @@ class DoConcurrentConversion
274
276
}
275
277
276
278
mlir::omp::ParallelOp parallelOp =
277
- genParallelOp (doLoop. getLoc (), rewriter , ivInfos, mapper);
279
+ genParallelOp (rewriter, loop , ivInfos, mapper);
278
280
279
281
// Only set as composite when part of `distribute parallel do`.
280
282
parallelOp.setComposite (mapToDevice);
281
283
282
284
if (!mapToDevice)
283
- genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, mapper,
284
- loopNestClauseOps);
285
+ genLoopNestClauseOps (doLoop.getLoc (), rewriter, loop, loopNestClauseOps);
285
286
286
287
for (mlir::Value local : locals)
287
288
looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
@@ -290,10 +291,38 @@ class DoConcurrentConversion
290
291
if (mapToDevice)
291
292
genDistributeOp (doLoop.getLoc (), rewriter).setComposite (/* val=*/ true );
292
293
293
- mlir::omp::LoopNestOp ompLoopNest =
294
+ auto [loopNestOp, wsLoopOp] =
294
295
genWsLoopOp (rewriter, loop, mapper, loopNestClauseOps,
295
296
/* isComposite=*/ mapToDevice);
296
297
298
+ // `local` region arguments are transferred/cloned from the `do concurrent`
299
+ // loop to the loopnest op when the region is cloned above. Instead, these
300
+ // region arguments should be on the workshare loop's region.
301
+ if (mapToDevice) {
302
+ for (auto [parallelArg, loopNestArg] : llvm::zip_equal (
303
+ parallelOp.getRegion ().getArguments (),
304
+ loopNestOp.getRegion ().getArguments ().slice (
305
+ loop.getLocalOperandsStart (), loop.getNumLocalOperands ())))
306
+ rewriter.replaceAllUsesWith (loopNestArg, parallelArg);
307
+
308
+ for (auto [wsloopArg, loopNestArg] : llvm::zip_equal (
309
+ wsLoopOp.getRegion ().getArguments (),
310
+ loopNestOp.getRegion ().getArguments ().slice (
311
+ loop.getReduceOperandsStart (), loop.getNumReduceOperands ())))
312
+ rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
313
+ } else {
314
+ for (auto [wsloopArg, loopNestArg] :
315
+ llvm::zip_equal (wsLoopOp.getRegion ().getArguments (),
316
+ loopNestOp.getRegion ().getArguments ().drop_front (
317
+ loopNestClauseOps.loopLowerBounds .size ())))
318
+ rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
319
+ }
320
+
321
+ for (unsigned i = 0 ;
322
+ i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
323
+ loopNestOp.getRegion ().eraseArgument (
324
+ loopNestClauseOps.loopLowerBounds .size ());
325
+
297
326
rewriter.setInsertionPoint (doLoop);
298
327
fir::FirOpBuilder builder (
299
328
rewriter,
@@ -314,7 +343,7 @@ class DoConcurrentConversion
314
343
// Mark `unordered` loops that are not perfectly nested to be skipped from
315
344
// the legality check of the `ConversionTarget` since we are not interested
316
345
// in mapping them to OpenMP.
317
- ompLoopNest ->walk ([&](fir::DoConcurrentOp doLoop) {
346
+ loopNestOp ->walk ([&](fir::DoConcurrentOp doLoop) {
318
347
concurrentLoopsToSkip.insert (doLoop);
319
348
});
320
349
@@ -370,11 +399,21 @@ class DoConcurrentConversion
370
399
llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
371
400
372
401
mlir::omp::ParallelOp
373
- genParallelOp (mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
402
+ genParallelOp (mlir::ConversionPatternRewriter &rewriter,
403
+ fir::DoConcurrentLoopOp loop,
374
404
looputils::InductionVariableInfos &ivInfos,
375
405
mlir::IRMapping &mapper) const {
376
- auto parallelOp = mlir::omp::ParallelOp::create (rewriter, loc);
377
- rewriter.createBlock (¶llelOp.getRegion ());
406
+ mlir::omp::ParallelOperands parallelOps;
407
+
408
+ if (mapToDevice)
409
+ genPrivatizers (rewriter, mapper, loop, parallelOps);
410
+
411
+ mlir::Location loc = loop.getLoc ();
412
+ auto parallelOp = mlir::omp::ParallelOp::create (rewriter, loc, parallelOps);
413
+ Fortran::common::openmp::EntryBlockArgs parallelArgs;
414
+ parallelArgs.priv .vars = parallelOps.privateVars ;
415
+ Fortran::common::openmp::genEntryBlock (rewriter, parallelArgs,
416
+ parallelOp.getRegion ());
378
417
rewriter.setInsertionPoint (mlir::omp::TerminatorOp::create (rewriter, loc));
379
418
380
419
genLoopNestIndVarAllocs (rewriter, ivInfos, mapper);
@@ -411,7 +450,7 @@ class DoConcurrentConversion
411
450
412
451
void genLoopNestClauseOps (
413
452
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
414
- fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
453
+ fir::DoConcurrentLoopOp loop,
415
454
mlir::omp::LoopNestOperands &loopNestClauseOps,
416
455
mlir::omp::TargetOperands *targetClauseOps = nullptr ) const {
417
456
assert (loopNestClauseOps.loopLowerBounds .empty () &&
@@ -440,59 +479,14 @@ class DoConcurrentConversion
440
479
loopNestClauseOps.loopInclusive = rewriter.getUnitAttr ();
441
480
}
442
481
443
- mlir::omp::LoopNestOp
482
+ std::pair< mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
444
483
genWsLoopOp (mlir::ConversionPatternRewriter &rewriter,
445
484
fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
446
485
const mlir::omp::LoopNestOperands &clauseOps,
447
486
bool isComposite) const {
448
487
mlir::omp::WsloopOperands wsloopClauseOps;
449
-
450
- auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
451
- mlir::Region &ompRegion) {
452
- if (!firRegion.empty ()) {
453
- rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
454
- auto firYield =
455
- mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
456
- rewriter.setInsertionPoint (firYield);
457
- mlir::omp::YieldOp::create (rewriter, firYield.getLoc (),
458
- firYield.getOperands ());
459
- rewriter.eraseOp (firYield);
460
- }
461
- };
462
-
463
- // For `local` (and `local_init`) opernads, emit corresponding `private`
464
- // clauses and attach these clauses to the workshare loop.
465
- if (!loop.getLocalVars ().empty ())
466
- for (auto [op, sym, arg] : llvm::zip_equal (
467
- loop.getLocalVars (),
468
- loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
469
- loop.getRegionLocalArgs ())) {
470
- auto localizer = moduleSymbolTable.lookup <fir::LocalitySpecifierOp>(
471
- sym.getLeafReference ());
472
- if (localizer.getLocalitySpecifierType () ==
473
- fir::LocalitySpecifierType::LocalInit)
474
- TODO (localizer.getLoc (),
475
- " local_init conversion is not supported yet" );
476
-
477
- mlir::OpBuilder::InsertionGuard guard (rewriter);
478
- rewriter.setInsertionPointAfter (localizer);
479
-
480
- auto privatizer = mlir::omp::PrivateClauseOp::create (
481
- rewriter, localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
482
- localizer.getTypeAttr ().getValue (),
483
- mlir::omp::DataSharingClauseType::Private);
484
-
485
- cloneFIRRegionToOMP (localizer.getInitRegion (),
486
- privatizer.getInitRegion ());
487
- cloneFIRRegionToOMP (localizer.getDeallocRegion (),
488
- privatizer.getDeallocRegion ());
489
-
490
- moduleSymbolTable.insert (privatizer);
491
-
492
- wsloopClauseOps.privateVars .push_back (op);
493
- wsloopClauseOps.privateSyms .push_back (
494
- mlir::SymbolRefAttr::get (privatizer));
495
- }
488
+ if (!mapToDevice)
489
+ genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
496
490
497
491
if (!loop.getReduceVars ().empty ()) {
498
492
for (auto [op, byRef, sym, arg] : llvm::zip_equal (
@@ -515,15 +509,15 @@ class DoConcurrentConversion
515
509
rewriter, firReducer.getLoc (), ompReducerName,
516
510
firReducer.getTypeAttr ().getValue ());
517
511
518
- cloneFIRRegionToOMP (firReducer.getAllocRegion (),
512
+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
519
513
ompReducer.getAllocRegion ());
520
- cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
514
+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
521
515
ompReducer.getInitializerRegion ());
522
- cloneFIRRegionToOMP (firReducer.getReductionRegion (),
516
+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
523
517
ompReducer.getReductionRegion ());
524
- cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
518
+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
525
519
ompReducer.getAtomicReductionRegion ());
526
- cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
520
+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
527
521
ompReducer.getCleanupRegion ());
528
522
moduleSymbolTable.insert (ompReducer);
529
523
}
@@ -555,21 +549,10 @@ class DoConcurrentConversion
555
549
556
550
rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
557
551
mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
552
+ loop->getParentOfType <mlir::ModuleOp>().print (
553
+ llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
558
554
559
- // `local` region arguments are transferred/cloned from the `do concurrent`
560
- // loop to the loopnest op when the region is cloned above. Instead, these
561
- // region arguments should be on the workshare loop's region.
562
- for (auto [wsloopArg, loopNestArg] :
563
- llvm::zip_equal (wsloopOp.getRegion ().getArguments (),
564
- loopNestOp.getRegion ().getArguments ().drop_front (
565
- clauseOps.loopLowerBounds .size ())))
566
- rewriter.replaceAllUsesWith (loopNestArg, wsloopArg);
567
-
568
- for (unsigned i = 0 ;
569
- i < loop.getLocalVars ().size () + loop.getReduceVars ().size (); ++i)
570
- loopNestOp.getRegion ().eraseArgument (clauseOps.loopLowerBounds .size ());
571
-
572
- return loopNestOp;
555
+ return {loopNestOp, wsloopOp};
573
556
}
574
557
575
558
void genBoundsOps (fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -810,6 +793,59 @@ class DoConcurrentConversion
810
793
return distOp;
811
794
}
812
795
796
+ void cloneFIRRegionToOMP (mlir::ConversionPatternRewriter &rewriter,
797
+ mlir::Region &firRegion,
798
+ mlir::Region &ompRegion) const {
799
+ if (!firRegion.empty ()) {
800
+ rewriter.cloneRegionBefore (firRegion, ompRegion, ompRegion.begin ());
801
+ auto firYield =
802
+ mlir::cast<fir::YieldOp>(ompRegion.back ().getTerminator ());
803
+ rewriter.setInsertionPoint (firYield);
804
+ mlir::omp::YieldOp::create (rewriter, firYield.getLoc (),
805
+ firYield.getOperands ());
806
+ rewriter.eraseOp (firYield);
807
+ }
808
+ }
809
+
810
+ void genPrivatizers (mlir::ConversionPatternRewriter &rewriter,
811
+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
812
+ mlir::omp::PrivateClauseOps &privateClauseOps) const {
813
+ // For `local` (and `local_init`) operands, emit corresponding `private`
814
+ // clauses and attach these clauses to the workshare loop.
815
+ if (!loop.getLocalVars ().empty ())
816
+ for (auto [var, sym, arg] : llvm::zip_equal (
817
+ loop.getLocalVars (),
818
+ loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
819
+ loop.getRegionLocalArgs ())) {
820
+ auto localizer = moduleSymbolTable.lookup <fir::LocalitySpecifierOp>(
821
+ sym.getLeafReference ());
822
+ if (localizer.getLocalitySpecifierType () ==
823
+ fir::LocalitySpecifierType::LocalInit)
824
+ TODO (localizer.getLoc (),
825
+ " local_init conversion is not supported yet" );
826
+
827
+ mlir::OpBuilder::InsertionGuard guard (rewriter);
828
+ rewriter.setInsertionPointAfter (localizer);
829
+
830
+ auto privatizer = mlir::omp::PrivateClauseOp::create (
831
+ rewriter, localizer.getLoc (), sym.getLeafReference ().str () + " .omp" ,
832
+ localizer.getTypeAttr ().getValue (),
833
+ mlir::omp::DataSharingClauseType::Private);
834
+
835
+ cloneFIRRegionToOMP (rewriter, localizer.getInitRegion (),
836
+ privatizer.getInitRegion ());
837
+ cloneFIRRegionToOMP (rewriter, localizer.getDeallocRegion (),
838
+ privatizer.getDeallocRegion ());
839
+
840
+ moduleSymbolTable.insert (privatizer);
841
+
842
+ privateClauseOps.privateVars .push_back (mapToDevice ? mapper.lookup (var)
843
+ : var);
844
+ privateClauseOps.privateSyms .push_back (
845
+ mlir::SymbolRefAttr::get (privatizer));
846
+ }
847
+ }
848
+
813
849
bool mapToDevice;
814
850
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
815
851
mlir::SymbolTable &moduleSymbolTable;
0 commit comments