Skip to content

Commit 319705d

Browse files
authored
[flang] do concurrent: fix reduction symbol resolution when mapping to OpenMP (#155355)
Fixes #155273. This PR introduces two changes: (1) the `do concurrent` to OpenMP pass is now a module pass rather than a function pass; (2) reduction ops are looked up in the parent module before being created. The benefit of using a module pass is that the same reduction operation can be reused across multiple functions if the reduction type matches.
1 parent b70fc3b commit 319705d

File tree

3 files changed

+72
-32
lines changed

3 files changed

+72
-32
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
5050
];
5151
}
5252

53-
def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::func::FuncOp"> {
53+
def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::ModuleOp"> {
5454
let summary = "Map `DO CONCURRENT` loops to OpenMP worksharing loops.";
5555

5656
let description = [{ This is an experimental pass to map `DO CONCURRENT` loops

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,11 @@ class DoConcurrentConversion
173173

174174
DoConcurrentConversion(
175175
mlir::MLIRContext *context, bool mapToDevice,
176-
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip)
176+
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
177+
mlir::SymbolTable &moduleSymbolTable)
177178
: OpConversionPattern(context), mapToDevice(mapToDevice),
178-
concurrentLoopsToSkip(concurrentLoopsToSkip) {}
179+
concurrentLoopsToSkip(concurrentLoopsToSkip),
180+
moduleSymbolTable(moduleSymbolTable) {}
179181

180182
mlir::LogicalResult
181183
matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
@@ -332,8 +334,8 @@ class DoConcurrentConversion
332334
loop.getLocalVars(),
333335
loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
334336
loop.getRegionLocalArgs())) {
335-
auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom<
336-
fir::LocalitySpecifierOp>(loop, sym);
337+
auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
338+
sym.getLeafReference());
337339
if (localizer.getLocalitySpecifierType() ==
338340
fir::LocalitySpecifierType::LocalInit)
339341
TODO(localizer.getLoc(),
@@ -352,6 +354,8 @@ class DoConcurrentConversion
352354
cloneFIRRegionToOMP(localizer.getDeallocRegion(),
353355
privatizer.getDeallocRegion());
354356

357+
moduleSymbolTable.insert(privatizer);
358+
355359
wsloopClauseOps.privateVars.push_back(op);
356360
wsloopClauseOps.privateSyms.push_back(
357361
mlir::SymbolRefAttr::get(privatizer));
@@ -362,28 +366,34 @@ class DoConcurrentConversion
362366
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
363367
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
364368
loop.getRegionReduceArgs())) {
365-
auto firReducer =
366-
mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
367-
loop, sym);
369+
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
370+
sym.getLeafReference());
368371

369372
mlir::OpBuilder::InsertionGuard guard(rewriter);
370373
rewriter.setInsertionPointAfter(firReducer);
371-
372-
auto ompReducer = mlir::omp::DeclareReductionOp::create(
373-
rewriter, firReducer.getLoc(),
374-
sym.getLeafReference().str() + ".omp",
375-
firReducer.getTypeAttr().getValue());
376-
377-
cloneFIRRegionToOMP(firReducer.getAllocRegion(),
378-
ompReducer.getAllocRegion());
379-
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
380-
ompReducer.getInitializerRegion());
381-
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
382-
ompReducer.getReductionRegion());
383-
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
384-
ompReducer.getAtomicReductionRegion());
385-
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
386-
ompReducer.getCleanupRegion());
374+
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
375+
376+
auto ompReducer =
377+
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
378+
rewriter.getStringAttr(ompReducerName));
379+
380+
if (!ompReducer) {
381+
ompReducer = mlir::omp::DeclareReductionOp::create(
382+
rewriter, firReducer.getLoc(), ompReducerName,
383+
firReducer.getTypeAttr().getValue());
384+
385+
cloneFIRRegionToOMP(firReducer.getAllocRegion(),
386+
ompReducer.getAllocRegion());
387+
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
388+
ompReducer.getInitializerRegion());
389+
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
390+
ompReducer.getReductionRegion());
391+
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
392+
ompReducer.getAtomicReductionRegion());
393+
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
394+
ompReducer.getCleanupRegion());
395+
moduleSymbolTable.insert(ompReducer);
396+
}
387397

388398
wsloopClauseOps.reductionVars.push_back(op);
389399
wsloopClauseOps.reductionByref.push_back(byRef);
@@ -431,6 +441,7 @@ class DoConcurrentConversion
431441

432442
bool mapToDevice;
433443
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
444+
mlir::SymbolTable &moduleSymbolTable;
434445
};
435446

436447
class DoConcurrentConversionPass
@@ -444,12 +455,9 @@ class DoConcurrentConversionPass
444455
: DoConcurrentConversionPassBase(options) {}
445456

446457
void runOnOperation() override {
447-
mlir::func::FuncOp func = getOperation();
448-
449-
if (func.isDeclaration())
450-
return;
451-
458+
mlir::ModuleOp module = getOperation();
452459
mlir::MLIRContext *context = &getContext();
460+
mlir::SymbolTable moduleSymbolTable(module);
453461

454462
if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host &&
455463
mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) {
@@ -463,7 +471,7 @@ class DoConcurrentConversionPass
463471
mlir::RewritePatternSet patterns(context);
464472
patterns.insert<DoConcurrentConversion>(
465473
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
466-
concurrentLoopsToSkip);
474+
concurrentLoopsToSkip, moduleSymbolTable);
467475
mlir::ConversionTarget target(*context);
468476
target.addDynamicallyLegalOp<fir::DoConcurrentOp>(
469477
[&](fir::DoConcurrentOp op) {
@@ -472,8 +480,8 @@ class DoConcurrentConversionPass
472480
target.markUnknownOpDynamicallyLegal(
473481
[](mlir::Operation *) { return true; });
474482

475-
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
476-
std::move(patterns)))) {
483+
if (mlir::failed(
484+
mlir::applyFullConversion(module, target, std::move(patterns)))) {
477485
signalPassFailure();
478486
}
479487
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=host %s -o - \
2+
! RUN: | FileCheck %s
3+
4+
subroutine test1(x,s,N)
5+
real :: x(N), s
6+
integer :: N
7+
do concurrent(i=1:N) reduce(+:s)
8+
s=s+x(i)
9+
end do
10+
end subroutine test1
11+
subroutine test2(x,s,N)
12+
real :: x(N), s
13+
integer :: N
14+
do concurrent(i=1:N) reduce(+:s)
15+
s=s+x(i)
16+
end do
17+
end subroutine test2
18+
19+
! CHECK: omp.declare_reduction @[[RED_SYM:.*]] : f32 init
20+
! CHECK-NOT: omp.declare_reduction
21+
22+
! CHECK-LABEL: func.func @_QPtest1
23+
! CHECK: omp.parallel {
24+
! CHECK: omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
25+
! CHECK: }
26+
! CHECK: }
27+
28+
! CHECK-LABEL: func.func @_QPtest2
29+
! CHECK: omp.parallel {
30+
! CHECK: omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
31+
! CHECK: }
32+
! CHECK: }

0 commit comments

Comments
 (0)