Skip to content

Commit 38dffb1

Browse files
committed
[flang] Lower omp.workshare to other omp constructs
1 parent a14789a commit 38dffb1

File tree

16 files changed

+528
-0
lines changed

16 files changed

+528
-0
lines changed

flang/include/flang/Optimizer/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ add_subdirectory(CodeGen)
22
add_subdirectory(Dialect)
33
add_subdirectory(HLFIR)
44
add_subdirectory(Transforms)
5+
add_subdirectory(OpenMP)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name FlangOpenMP)
3+
4+
add_public_tablegen_target(FlangOpenMPPassesIncGen)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- Passes.h - OpenMP pass entry points ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header declares OpenMP pass entry points.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef FORTRAN_OPTIMIZER_OPENMP_PASSES_H
14+
#define FORTRAN_OPTIMIZER_OPENMP_PASSES_H
15+
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Pass/Pass.h"
18+
#include "mlir/Pass/PassRegistry.h"
19+
#include <memory>
20+
21+
namespace flangomp {
22+
#define GEN_PASS_DECL
23+
#define GEN_PASS_REGISTRATION
24+
#include "flang/Optimizer/OpenMP/Passes.h.inc"
25+
26+
bool shouldUseWorkshareLowering(mlir::Operation *op);
27+
28+
} // namespace flangomp
29+
30+
#endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===-- Passes.td - HLFIR pass definition file -------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_DIALECT_OPENMP_PASSES
10+
#define FORTRAN_DIALECT_OPENMP_PASSES
11+
12+
include "mlir/Pass/PassBase.td"
13+
14+
def LowerWorkshare : Pass<"lower-workshare"> {
15+
let summary = "Lower workshare construct";
16+
}
17+
18+
#endif //FORTRAN_DIALECT_OPENMP_PASSES

flang/include/flang/Tools/CLOptions.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Transforms/Passes.h"
1818
#include "flang/Optimizer/CodeGen/CodeGen.h"
1919
#include "flang/Optimizer/HLFIR/Passes.h"
20+
#include "flang/Optimizer/OpenMP/Passes.h"
2021
#include "flang/Optimizer/Transforms/Passes.h"
2122
#include "llvm/Passes/OptimizationLevel.h"
2223
#include "llvm/Support/CommandLine.h"
@@ -344,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
344345
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
345346
pm.addPass(hlfir::createBufferizeHLFIR());
346347
pm.addPass(hlfir::createConvertHLFIRtoFIR());
348+
pm.addPass(flangomp::createLowerWorkshare());
347349
}
348350

349351
/// Create a pass pipeline for handling certain OpenMP transformations needed

flang/lib/Frontend/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ add_flang_library(flangFrontend
3838
FIRTransforms
3939
HLFIRDialect
4040
HLFIRTransforms
41+
FlangOpenMPTransforms
4142
MLIRTransforms
4243
MLIRBuiltinToLLVMIRTranslation
4344
MLIRLLVMToLLVMIRTranslation

flang/lib/Optimizer/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ add_subdirectory(HLFIR)
55
add_subdirectory(Support)
66
add_subdirectory(Transforms)
77
add_subdirectory(Analysis)
8+
add_subdirectory(OpenMP)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
2+
3+
add_flang_library(FlangOpenMPTransforms
4+
LowerWorkshare.cpp
5+
6+
DEPENDS
7+
FIRDialect
8+
FlangOpenMPPassesIncGen
9+
${dialect_libs}
10+
11+
LINK_LIBS
12+
FIRAnalysis
13+
FIRDialect
14+
FIRBuilder
15+
FIRDialectSupport
16+
FIRSupport
17+
FIRTransforms
18+
HLFIRDialect
19+
MLIRIR
20+
${dialect_libs}
21+
22+
LINK_COMPONENTS
23+
AsmParser
24+
AsmPrinter
25+
Remarks
26+
)
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
//===- LowerWorkshare.cpp - special cases for bufferization -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Lower omp workshare construct.
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "flang/Optimizer/Dialect/FIROps.h"
12+
#include "flang/Optimizer/Dialect/FIRType.h"
13+
#include "flang/Optimizer/OpenMP/Passes.h"
14+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15+
#include "mlir/IR/BuiltinOps.h"
16+
#include "mlir/IR/IRMapping.h"
17+
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Support/LLVM.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/ADT/iterator_range.h"
23+
24+
#include <variant>
25+
26+
namespace flangomp {
27+
#define GEN_PASS_DEF_LOWERWORKSHARE
28+
#include "flang/Optimizer/OpenMP/Passes.h.inc"
29+
} // namespace flangomp
30+
31+
#define DEBUG_TYPE "lower-workshare"
32+
33+
using namespace mlir;
34+
35+
namespace flangomp {
36+
bool shouldUseWorkshareLowering(Operation *op) {
37+
auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
38+
if (!workshare)
39+
return false;
40+
return workshare->getParentOfType<omp::ParallelOp>();
41+
}
42+
} // namespace flangomp
43+
44+
namespace {
45+
46+
struct SingleRegion {
47+
Block::iterator begin, end;
48+
};
49+
50+
static bool isSupportedByFirAlloca(Type ty) {
51+
return !isa<fir::ReferenceType>(ty);
52+
}
53+
54+
static bool isSafeToParallelize(Operation *op) {
55+
if (isa<fir::DeclareOp>(op))
56+
return true;
57+
58+
llvm::SmallVector<MemoryEffects::EffectInstance> effects;
59+
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
60+
if (!interface) {
61+
return false;
62+
}
63+
interface.getEffects(effects);
64+
if (effects.empty())
65+
return true;
66+
67+
return false;
68+
}
69+
70+
/// Lowers workshare to a sequence of single-thread regions and parallel loops
71+
///
72+
/// For example:
73+
///
74+
/// omp.workshare {
75+
/// %a = fir.allocmem
76+
/// omp.wsloop {}
77+
/// fir.call Assign %b %a
78+
/// fir.freemem %a
79+
/// }
80+
///
81+
/// becomes
82+
///
83+
/// omp.single {
84+
/// %a = fir.allocmem
85+
/// fir.store %a %tmp
86+
/// }
87+
/// %a_reloaded = fir.load %tmp
88+
/// omp.wsloop {}
89+
/// omp.single {
90+
/// fir.call Assign %b %a_reloaded
91+
/// fir.freemem %a_reloaded
92+
/// }
93+
///
94+
/// Note that we allocate temporary memory for values in omp.single's which need
95+
/// to be accessed in all threads in the closest omp.parallel
96+
///
97+
/// TODO currently we need to be able to access the encompassing omp.parallel so
98+
/// that we can allocate temporaries accessible by all threads outside of it.
99+
/// In case we do not find it, we fall back to converting the omp.workshare to
100+
/// omp.single.
101+
/// To better handle this we should probably enable yielding values out of an
102+
/// omp.single which will be supported by the omp runtime.
103+
void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
104+
assert(wsOp.getRegion().getBlocks().size() == 1);
105+
106+
Location loc = wsOp->getLoc();
107+
108+
omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
109+
if (!parallelOp) {
110+
wsOp.emitWarning("cannot handle workshare, converting to single");
111+
Operation *terminator = wsOp.getRegion().front().getTerminator();
112+
wsOp->getBlock()->getOperations().splice(
113+
wsOp->getIterator(), wsOp.getRegion().front().getOperations());
114+
terminator->erase();
115+
return;
116+
}
117+
118+
OpBuilder allocBuilder(parallelOp);
119+
OpBuilder rootBuilder(wsOp);
120+
IRMapping rootMapping;
121+
122+
omp::SingleOp singleOp = nullptr;
123+
124+
auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
125+
IRMapping singleMapping) {
126+
if (auto reloaded = rootMapping.lookupOrNull(v))
127+
return;
128+
Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
129+
Type ty = v.getType();
130+
Value alloc, reloaded;
131+
if (isSupportedByFirAlloca(ty)) {
132+
alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
133+
singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
134+
reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
135+
} else {
136+
auto one = allocBuilder.create<LLVM::ConstantOp>(
137+
loc, allocBuilder.getI32Type(), 1);
138+
alloc =
139+
allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
140+
Value toStore = singleBuilder
141+
.create<UnrealizedConversionCastOp>(
142+
loc, llvmPtrTy, singleMapping.lookup(v))
143+
.getResult(0);
144+
singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
145+
reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
146+
reloaded =
147+
rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
148+
.getResult(0);
149+
}
150+
rootMapping.map(v, reloaded);
151+
};
152+
153+
auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
154+
IRMapping singleMapping = rootMapping;
155+
156+
for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
157+
singleBuilder.clone(op, singleMapping);
158+
if (isSafeToParallelize(&op)) {
159+
rootBuilder.clone(op, rootMapping);
160+
} else {
161+
// Prepare reloaded values for results of operations that cannot be
162+
// safely parallelized and which are used after the region `sr`
163+
for (auto res : op.getResults()) {
164+
for (auto &use : res.getUses()) {
165+
Operation *user = use.getOwner();
166+
while (user->getParentOp() != wsOp)
167+
user = user->getParentOp();
168+
if (!user->isBeforeInBlock(&*sr.end)) {
169+
// We need to reload
170+
mapReloadedValue(use.get(), singleBuilder, singleMapping);
171+
}
172+
}
173+
}
174+
}
175+
}
176+
singleBuilder.create<omp::TerminatorOp>(loc);
177+
};
178+
179+
Block *wsBlock = &wsOp.getRegion().front();
180+
assert(wsBlock->getTerminator()->getNumOperands() == 0);
181+
Operation *terminator = wsBlock->getTerminator();
182+
183+
SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
184+
185+
auto it = wsBlock->begin();
186+
auto getSingleRegion = [&]() {
187+
if (&*it == terminator)
188+
return false;
189+
if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
190+
regions.push_back(pop);
191+
it++;
192+
return true;
193+
}
194+
SingleRegion sr;
195+
sr.begin = it;
196+
while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
197+
it++;
198+
sr.end = it;
199+
assert(sr.begin != sr.end);
200+
regions.push_back(sr);
201+
return true;
202+
};
203+
while (getSingleRegion())
204+
;
205+
206+
for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
207+
bool isLast = i + 1 == regions.size();
208+
if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
209+
omp::SingleOperands singleOperands;
210+
if (isLast)
211+
singleOperands.nowait = rootBuilder.getUnitAttr();
212+
singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
213+
OpBuilder singleBuilder(singleOp);
214+
singleBuilder.createBlock(&singleOp.getRegion());
215+
moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
216+
} else {
217+
rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
218+
if (!isLast)
219+
rootBuilder.create<omp::BarrierOp>(loc);
220+
}
221+
}
222+
223+
if (!wsOp.getNowait())
224+
rootBuilder.create<omp::BarrierOp>(loc);
225+
226+
wsOp->erase();
227+
228+
return;
229+
}
230+
231+
class LowerWorksharePass
232+
: public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
233+
public:
234+
void runOnOperation() override {
235+
SmallPtrSet<Operation *, 8> parents;
236+
getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
237+
Operation *isolatedParent =
238+
wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
239+
parents.insert(isolatedParent);
240+
241+
lowerWorkshare(wsOp);
242+
});
243+
244+
// Do folding
245+
for (Operation *isolatedParent : parents) {
246+
RewritePatternSet patterns(&getContext());
247+
GreedyRewriteConfig config;
248+
// prevent the pattern driver form merging blocks
249+
config.enableRegionSimplification =
250+
mlir::GreedySimplifyRegionLevel::Disabled;
251+
if (failed(applyPatternsAndFoldGreedily(isolatedParent,
252+
std::move(patterns), config))) {
253+
emitError(isolatedParent->getLoc(), "error in lower workshare\n");
254+
signalPassFailure();
255+
}
256+
}
257+
}
258+
};
259+
} // namespace

0 commit comments

Comments
 (0)