Skip to content

Commit 30e1294

Browse files
committed
[flang] Lower omp.workshare to other omp constructs
1 parent a9f4b77 commit 30e1294

File tree

7 files changed

+391
-1
lines changed

7 files changed

+391
-1
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ namespace flangomp {
2525
#define GEN_PASS_REGISTRATION
2626
#include "flang/Optimizer/OpenMP/Passes.h.inc"
2727

28+
bool shouldUseWorkshareLowering(mlir::Operation *op);
29+
2830
} // namespace flangomp
2931

3032
#endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- Passes.td - HLFIR pass definition file -------------*- tablegen -*-===//
1+
//===-- Passes.td - flang OpenMP pass definition -----------*- tablegen -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -37,4 +37,8 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
3737
];
3838
}
3939

40+
def LowerWorkshare : Pass<"lower-workshare"> {
41+
let summary = "Lower workshare construct";
42+
}
43+
4044
#endif //FORTRAN_OPTIMIZER_OPENMP_PASSES

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
345345
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
346346
pm.addPass(hlfir::createBufferizeHLFIR());
347347
pm.addPass(hlfir::createConvertHLFIRtoFIR());
348+
pm.addPass(flangomp::createLowerWorkshare());
348349
}
349350

350351
/// Create a pass pipeline for handling certain OpenMP transformations needed

flang/lib/Optimizer/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_flang_library(FlangOpenMPTransforms
44
FunctionFiltering.cpp
55
MapInfoFinalization.cpp
66
MarkDeclareTarget.cpp
7+
LowerWorkshare.cpp
78

89
DEPENDS
910
FIRDialect
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
//===- LowerWorkshare.cpp - lower the omp workshare construct -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Lower omp workshare construct.
9+
//===----------------------------------------------------------------------===//
10+
11+
#include "flang/Optimizer/Dialect/FIROps.h"
12+
#include "flang/Optimizer/Dialect/FIRType.h"
13+
#include "flang/Optimizer/OpenMP/Passes.h"
14+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15+
#include "mlir/IR/BuiltinOps.h"
16+
#include "mlir/IR/IRMapping.h"
17+
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Support/LLVM.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/ADT/iterator_range.h"
23+
24+
#include <variant>
25+
26+
namespace flangomp {
27+
#define GEN_PASS_DEF_LOWERWORKSHARE
28+
#include "flang/Optimizer/OpenMP/Passes.h.inc"
29+
} // namespace flangomp
30+
31+
#define DEBUG_TYPE "lower-workshare"
32+
33+
using namespace mlir;
34+
35+
namespace flangomp {
36+
bool shouldUseWorkshareLowering(Operation *op) {
37+
auto workshare = dyn_cast<omp::WorkshareOp>(op->getParentOp());
38+
if (!workshare)
39+
return false;
40+
return workshare->getParentOfType<omp::ParallelOp>();
41+
}
42+
} // namespace flangomp
43+
44+
namespace {
45+
46+
struct SingleRegion {
47+
Block::iterator begin, end;
48+
};
49+
50+
/// Values of fir.ref type are not stored via fir.alloca here; they are
/// routed through an llvm.alloca slot instead (see mapReloadedValue).
static bool isSupportedByFirAlloca(Type ty) {
  return !isa<fir::ReferenceType>(ty);
}
53+
54+
/// An op may be cloned into every thread (instead of being confined to an
/// omp.single) when it reports no memory effects. fir.declare is treated as
/// safe unconditionally.
static bool isSafeToParallelize(Operation *op) {
  if (isa<fir::DeclareOp>(op))
    return true;

  auto effectIface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!effectIface)
    // Ops without the interface are conservatively kept single-threaded.
    return false;

  llvm::SmallVector<MemoryEffects::EffectInstance> effects;
  effectIface.getEffects(effects);
  return effects.empty();
}
69+
70+
/// Lowers workshare to a sequence of single-thread regions and parallel loops
71+
///
72+
/// For example:
73+
///
74+
/// omp.workshare {
75+
/// %a = fir.allocmem
76+
/// omp.wsloop {}
77+
/// fir.call Assign %b %a
78+
/// fir.freemem %a
79+
/// }
80+
///
81+
/// becomes
82+
///
83+
/// omp.single {
84+
/// %a = fir.allocmem
85+
/// fir.store %a %tmp
86+
/// }
87+
/// %a_reloaded = fir.load %tmp
88+
/// omp.wsloop {}
89+
/// omp.single {
90+
/// fir.call Assign %b %a_reloaded
91+
/// fir.freemem %a_reloaded
92+
/// }
93+
///
94+
/// Note that we allocate temporary memory for values in omp.single's which need
95+
/// to be accessed in all threads in the closest omp.parallel
96+
///
97+
/// TODO currently we need to be able to access the encompassing omp.parallel so
98+
/// that we can allocate temporaries accessible by all threads outside of it.
99+
/// In case we do not find it, we fall back to converting the omp.workshare to
100+
/// omp.single.
101+
/// To better handle this we should probably enable yielding values out of an
102+
/// omp.single which will be supported by the omp runtime.
103+
void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
104+
assert(wsOp.getRegion().getBlocks().size() == 1);
105+
106+
Location loc = wsOp->getLoc();
107+
108+
omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
109+
if (!parallelOp) {
110+
wsOp.emitWarning("cannot handle workshare, converting to single");
111+
Operation *terminator = wsOp.getRegion().front().getTerminator();
112+
wsOp->getBlock()->getOperations().splice(
113+
wsOp->getIterator(), wsOp.getRegion().front().getOperations());
114+
terminator->erase();
115+
return;
116+
}
117+
118+
OpBuilder allocBuilder(parallelOp);
119+
OpBuilder rootBuilder(wsOp);
120+
IRMapping rootMapping;
121+
122+
omp::SingleOp singleOp = nullptr;
123+
124+
auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
125+
IRMapping singleMapping) {
126+
if (auto reloaded = rootMapping.lookupOrNull(v))
127+
return;
128+
Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
129+
Type ty = v.getType();
130+
Value alloc, reloaded;
131+
if (isSupportedByFirAlloca(ty)) {
132+
alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
133+
singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
134+
reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
135+
} else {
136+
auto one = allocBuilder.create<LLVM::ConstantOp>(
137+
loc, allocBuilder.getI32Type(), 1);
138+
alloc =
139+
allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
140+
Value toStore = singleBuilder
141+
.create<UnrealizedConversionCastOp>(
142+
loc, llvmPtrTy, singleMapping.lookup(v))
143+
.getResult(0);
144+
singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
145+
reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
146+
reloaded =
147+
rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
148+
.getResult(0);
149+
}
150+
rootMapping.map(v, reloaded);
151+
};
152+
153+
auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
154+
IRMapping singleMapping = rootMapping;
155+
156+
for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
157+
singleBuilder.clone(op, singleMapping);
158+
if (isSafeToParallelize(&op)) {
159+
rootBuilder.clone(op, rootMapping);
160+
} else {
161+
// Prepare reloaded values for results of operations that cannot be
162+
// safely parallelized and which are used after the region `sr`
163+
for (auto res : op.getResults()) {
164+
for (auto &use : res.getUses()) {
165+
Operation *user = use.getOwner();
166+
while (user->getParentOp() != wsOp)
167+
user = user->getParentOp();
168+
if (!user->isBeforeInBlock(&*sr.end)) {
169+
// We need to reload
170+
mapReloadedValue(use.get(), singleBuilder, singleMapping);
171+
}
172+
}
173+
}
174+
}
175+
}
176+
singleBuilder.create<omp::TerminatorOp>(loc);
177+
};
178+
179+
Block *wsBlock = &wsOp.getRegion().front();
180+
assert(wsBlock->getTerminator()->getNumOperands() == 0);
181+
Operation *terminator = wsBlock->getTerminator();
182+
183+
SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
184+
185+
auto it = wsBlock->begin();
186+
auto getSingleRegion = [&]() {
187+
if (&*it == terminator)
188+
return false;
189+
if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
190+
regions.push_back(pop);
191+
it++;
192+
return true;
193+
}
194+
SingleRegion sr;
195+
sr.begin = it;
196+
while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
197+
it++;
198+
sr.end = it;
199+
assert(sr.begin != sr.end);
200+
regions.push_back(sr);
201+
return true;
202+
};
203+
while (getSingleRegion())
204+
;
205+
206+
for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
207+
bool isLast = i + 1 == regions.size();
208+
if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
209+
omp::SingleOperands singleOperands;
210+
if (isLast)
211+
singleOperands.nowait = rootBuilder.getUnitAttr();
212+
singleOp = rootBuilder.create<omp::SingleOp>(loc, singleOperands);
213+
OpBuilder singleBuilder(singleOp);
214+
singleBuilder.createBlock(&singleOp.getRegion());
215+
moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
216+
} else {
217+
rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
218+
if (!isLast)
219+
rootBuilder.create<omp::BarrierOp>(loc);
220+
}
221+
}
222+
223+
if (!wsOp.getNowait())
224+
rootBuilder.create<omp::BarrierOp>(loc);
225+
226+
wsOp->erase();
227+
228+
return;
229+
}
230+
231+
class LowerWorksharePass
232+
: public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
233+
public:
234+
void runOnOperation() override {
235+
SmallPtrSet<Operation *, 8> parents;
236+
getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
237+
Operation *isolatedParent =
238+
wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
239+
parents.insert(isolatedParent);
240+
241+
lowerWorkshare(wsOp);
242+
});
243+
244+
// Do folding
245+
for (Operation *isolatedParent : parents) {
246+
RewritePatternSet patterns(&getContext());
247+
GreedyRewriteConfig config;
248+
// prevent the pattern driver form merging blocks
249+
config.enableRegionSimplification =
250+
mlir::GreedySimplifyRegionLevel::Disabled;
251+
if (failed(applyPatternsAndFoldGreedily(isolatedParent,
252+
std::move(patterns), config))) {
253+
emitError(isolatedParent->getLoc(), "error in lower workshare\n");
254+
signalPassFailure();
255+
}
256+
}
257+
}
258+
};
259+
} // namespace
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// RUN: fir-opt --lower-workshare %s | FileCheck %s
2+
3+
module {
4+
// CHECK-LABEL: func.func @simple(
5+
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
6+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
7+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
8+
// CHECK: %[[VAL_3:.*]] = arith.constant 42 : index
9+
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
10+
// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
11+
// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
12+
// CHECK: omp.parallel {
13+
// CHECK: omp.single {
14+
// CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
15+
// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
16+
// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
17+
// CHECK: llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
18+
// CHECK: %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
19+
// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
20+
// CHECK: fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
21+
// CHECK: omp.terminator
22+
// CHECK: }
23+
// CHECK: %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
24+
// CHECK: %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
25+
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
26+
// CHECK: omp.wsloop {
27+
// CHECK: omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
28+
// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
29+
// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
30+
// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
31+
// CHECK: %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
32+
// CHECK: hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
33+
// CHECK: omp.yield
34+
// CHECK: }
35+
// CHECK: omp.terminator
36+
// CHECK: }
37+
// CHECK: omp.barrier
38+
// CHECK: omp.single nowait {
39+
// CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
40+
// CHECK: fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
41+
// CHECK: omp.terminator
42+
// CHECK: }
43+
// CHECK: omp.barrier
44+
// CHECK: omp.terminator
45+
// CHECK: }
46+
// CHECK: return
47+
// CHECK: }
48+
func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
49+
omp.parallel {
50+
omp.workshare {
51+
%c42 = arith.constant 42 : index
52+
%c1_i32 = arith.constant 1 : i32
53+
%0 = fir.shape %c42 : (index) -> !fir.shape<1>
54+
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
55+
%2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
56+
%3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
57+
%true = arith.constant true
58+
%c1 = arith.constant 1 : index
59+
omp.wsloop {
60+
omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
61+
%7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
62+
%8 = fir.load %7 : !fir.ref<i32>
63+
%9 = arith.subi %8, %c1_i32 : i32
64+
%10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
65+
hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
66+
omp.yield
67+
}
68+
omp.terminator
69+
}
70+
%4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
71+
%5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
72+
%6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
73+
hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
74+
fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
75+
omp.terminator
76+
}
77+
omp.terminator
78+
}
79+
return
80+
}
81+
}

0 commit comments

Comments
 (0)