@@ -173,9 +173,11 @@ class DoConcurrentConversion
173173
174174 DoConcurrentConversion (
175175 mlir::MLIRContext *context, bool mapToDevice,
176- llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip)
176+ llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
177+ mlir::SymbolTable &moduleSymbolTable)
177178 : OpConversionPattern(context), mapToDevice(mapToDevice),
178- concurrentLoopsToSkip (concurrentLoopsToSkip) {}
179+ concurrentLoopsToSkip (concurrentLoopsToSkip),
180+ moduleSymbolTable(moduleSymbolTable) {}
179181
180182 mlir::LogicalResult
181183 matchAndRewrite (fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
@@ -332,8 +334,8 @@ class DoConcurrentConversion
332334 loop.getLocalVars (),
333335 loop.getLocalSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
334336 loop.getRegionLocalArgs ())) {
335- auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom<
336- fir::LocalitySpecifierOp>(loop, sym);
337+ auto localizer = moduleSymbolTable. lookup <fir::LocalitySpecifierOp>(
338+ sym. getLeafReference () );
337339 if (localizer.getLocalitySpecifierType () ==
338340 fir::LocalitySpecifierType::LocalInit)
339341 TODO (localizer.getLoc (),
@@ -352,6 +354,8 @@ class DoConcurrentConversion
352354 cloneFIRRegionToOMP (localizer.getDeallocRegion (),
353355 privatizer.getDeallocRegion ());
354356
357+ moduleSymbolTable.insert (privatizer);
358+
355359 wsloopClauseOps.privateVars .push_back (op);
356360 wsloopClauseOps.privateSyms .push_back (
357361 mlir::SymbolRefAttr::get (privatizer));
@@ -362,28 +366,34 @@ class DoConcurrentConversion
362366 loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
363367 loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
364368 loop.getRegionReduceArgs ())) {
365- auto firReducer =
366- mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
367- loop, sym);
369+ auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
370+ sym.getLeafReference ());
368371
369372 mlir::OpBuilder::InsertionGuard guard (rewriter);
370373 rewriter.setInsertionPointAfter (firReducer);
371-
372- auto ompReducer = mlir::omp::DeclareReductionOp::create (
373- rewriter, firReducer.getLoc (),
374- sym.getLeafReference ().str () + " .omp" ,
375- firReducer.getTypeAttr ().getValue ());
376-
377- cloneFIRRegionToOMP (firReducer.getAllocRegion (),
378- ompReducer.getAllocRegion ());
379- cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
380- ompReducer.getInitializerRegion ());
381- cloneFIRRegionToOMP (firReducer.getReductionRegion (),
382- ompReducer.getReductionRegion ());
383- cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
384- ompReducer.getAtomicReductionRegion ());
385- cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
386- ompReducer.getCleanupRegion ());
374+ std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
375+
376+ auto ompReducer =
377+ moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
378+ rewriter.getStringAttr (ompReducerName));
379+
380+ if (!ompReducer) {
381+ ompReducer = mlir::omp::DeclareReductionOp::create (
382+ rewriter, firReducer.getLoc (), ompReducerName,
383+ firReducer.getTypeAttr ().getValue ());
384+
385+ cloneFIRRegionToOMP (firReducer.getAllocRegion (),
386+ ompReducer.getAllocRegion ());
387+ cloneFIRRegionToOMP (firReducer.getInitializerRegion (),
388+ ompReducer.getInitializerRegion ());
389+ cloneFIRRegionToOMP (firReducer.getReductionRegion (),
390+ ompReducer.getReductionRegion ());
391+ cloneFIRRegionToOMP (firReducer.getAtomicReductionRegion (),
392+ ompReducer.getAtomicReductionRegion ());
393+ cloneFIRRegionToOMP (firReducer.getCleanupRegion (),
394+ ompReducer.getCleanupRegion ());
395+ moduleSymbolTable.insert (ompReducer);
396+ }
387397
388398 wsloopClauseOps.reductionVars .push_back (op);
389399 wsloopClauseOps.reductionByref .push_back (byRef);
@@ -431,6 +441,7 @@ class DoConcurrentConversion
431441
432442 bool mapToDevice;
433443 llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
444+ mlir::SymbolTable &moduleSymbolTable;
434445};
435446
436447class DoConcurrentConversionPass
@@ -444,12 +455,9 @@ class DoConcurrentConversionPass
444455 : DoConcurrentConversionPassBase(options) {}
445456
446457 void runOnOperation () override {
447- mlir::func::FuncOp func = getOperation ();
448-
449- if (func.isDeclaration ())
450- return ;
451-
458+ mlir::ModuleOp module = getOperation ();
452459 mlir::MLIRContext *context = &getContext ();
460+ mlir::SymbolTable moduleSymbolTable (module );
453461
454462 if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host &&
455463 mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) {
@@ -463,7 +471,7 @@ class DoConcurrentConversionPass
463471 mlir::RewritePatternSet patterns (context);
464472 patterns.insert <DoConcurrentConversion>(
465473 context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
466- concurrentLoopsToSkip);
474+ concurrentLoopsToSkip, moduleSymbolTable );
467475 mlir::ConversionTarget target (*context);
468476 target.addDynamicallyLegalOp <fir::DoConcurrentOp>(
469477 [&](fir::DoConcurrentOp op) {
@@ -472,8 +480,8 @@ class DoConcurrentConversionPass
472480 target.markUnknownOpDynamicallyLegal (
473481 [](mlir::Operation *) { return true ; });
474482
475- if (mlir::failed (mlir::applyFullConversion ( getOperation (), target,
476- std::move (patterns)))) {
483+ if (mlir::failed (
484+ mlir::applyFullConversion ( module , target, std::move (patterns)))) {
477485 signalPassFailure ();
478486 }
479487 }
0 commit comments