Skip to content

Commit 3ad05db

Browse files
committed
Correctly handle nested loop nests to be parallelized by workshare
1 parent a5a1021 commit 3ad05db

File tree

1 file changed

+138
-118
lines changed

1 file changed

+138
-118
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp

Lines changed: 138 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,14 @@
1919
#include "mlir/Support/LLVM.h"
2020
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2121
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/ADT/SmallVectorExtras.h"
2223
#include "llvm/ADT/iterator_range.h"
2324

25+
#include <mlir/Dialect/Arith/IR/Arith.h>
2426
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
27+
#include <mlir/Dialect/SCF/IR/SCF.h>
28+
#include <mlir/IR/Visitors.h>
29+
#include <mlir/Interfaces/SideEffectInterfaces.h>
2530
#include <variant>
2631

2732
namespace flangomp {
@@ -52,90 +57,40 @@ static bool isSupportedByFirAlloca(Type ty) {
5257
return !isa<fir::ReferenceType>(ty);
5358
}
5459

55-
static bool isSafeToParallelize(Operation *op) {
56-
if (isa<fir::DeclareOp>(op))
57-
return true;
58-
59-
llvm::SmallVector<MemoryEffects::EffectInstance> effects;
60-
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
61-
if (!interface) {
62-
return false;
63-
}
64-
interface.getEffects(effects);
65-
if (effects.empty())
66-
return true;
67-
68-
return false;
60+
static bool mustParallelizeOp(Operation *op) {
61+
return op
62+
->walk(
63+
[](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
64+
.wasInterrupted();
6965
}
7066

71-
/// Lowers workshare to a sequence of single-thread regions and parallel loops
72-
///
73-
/// For example:
74-
///
75-
/// omp.workshare {
76-
/// %a = fir.allocmem
77-
/// omp.workshare_loop_wrapper {}
78-
/// fir.call Assign %b %a
79-
/// fir.freemem %a
80-
/// }
81-
///
82-
/// becomes
83-
///
84-
/// omp.single {
85-
/// %a = fir.allocmem
86-
/// fir.store %a %tmp
87-
/// }
88-
/// %a_reloaded = fir.load %tmp
89-
/// omp.workshare_loop_wrapper {}
90-
/// omp.single {
91-
/// fir.call Assign %b %a_reloaded
92-
/// fir.freemem %a_reloaded
93-
/// }
94-
///
95-
/// Note that we allocate temporary memory for values in omp.single's which need
96-
/// to be accessed in all threads in the closest omp.parallel
97-
///
98-
/// TODO currently we need to be able to access the encompassing omp.parallel so
99-
/// that we can allocate temporaries accessible by all threads outside of it.
100-
/// In case we do not find it, we fall back to converting the omp.workshare to
101-
/// omp.single.
102-
/// To better handle this we should probably enable yielding values out of an
103-
/// omp.single which will be supported by the omp runtime.
104-
void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
105-
assert(wsOp.getRegion().getBlocks().size() == 1);
106-
107-
Location loc = wsOp->getLoc();
67+
static bool isSafeToParallelize(Operation *op) {
68+
return isa<fir::DeclareOp>(op) || isPure(op);
69+
}
10870

109-
omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
110-
if (!parallelOp) {
111-
wsOp.emitWarning("cannot handle workshare, converting to single");
112-
Operation *terminator = wsOp.getRegion().front().getTerminator();
113-
wsOp->getBlock()->getOperations().splice(
114-
wsOp->getIterator(), wsOp.getRegion().front().getOperations());
115-
terminator->erase();
116-
return;
117-
}
118-
119-
OpBuilder allocBuilder(parallelOp);
120-
OpBuilder rootBuilder(wsOp);
121-
IRMapping rootMapping;
71+
static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
72+
IRMapping &rootMapping, Location loc) {
73+
Operation *parentOp = sourceRegion.getParentOp();
74+
OpBuilder rootBuilder(sourceRegion.getContext());
12275

76+
// TODO need to copyprivate the alloca's
12377
auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
12478
IRMapping singleMapping) {
79+
OpBuilder allocaBuilder(&targetRegion.front().front());
12580
if (auto reloaded = rootMapping.lookupOrNull(v))
12681
return;
127-
Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
82+
Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
12883
Type ty = v.getType();
12984
Value alloc, reloaded;
13085
if (isSupportedByFirAlloca(ty)) {
131-
alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
86+
alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
13287
singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
13388
reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
13489
} else {
135-
auto one = allocBuilder.create<LLVM::ConstantOp>(
136-
loc, allocBuilder.getI32Type(), 1);
90+
auto one = allocaBuilder.create<LLVM::ConstantOp>(
91+
loc, allocaBuilder.getI32Type(), 1);
13792
alloc =
138-
allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
93+
allocaBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
13994
Value toStore = singleBuilder
14095
.create<UnrealizedConversionCastOp>(
14196
loc, llvmPtrTy, singleMapping.lookup(v))
@@ -162,9 +117,10 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
162117
for (auto res : op.getResults()) {
163118
for (auto &use : res.getUses()) {
164119
Operation *user = use.getOwner();
165-
while (user->getParentOp() != wsOp)
120+
while (user->getParentOp() != parentOp)
166121
user = user->getParentOp();
167-
if (!user->isBeforeInBlock(&*sr.end)) {
122+
if (!(user->isBeforeInBlock(&*sr.end) &&
123+
sr.begin->isBeforeInBlock(user))) {
168124
// We need to reload
169125
mapReloadedValue(use.get(), singleBuilder, singleMapping);
170126
}
@@ -175,61 +131,125 @@ void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
175131
singleBuilder.create<omp::TerminatorOp>(loc);
176132
};
177133

178-
Block *wsBlock = &wsOp.getRegion().front();
179-
assert(wsBlock->getTerminator()->getNumOperands() == 0);
180-
Operation *terminator = wsBlock->getTerminator();
134+
// TODO Need to handle these (clone them) in dominator tree order
135+
for (Block &block : sourceRegion) {
136+
rootBuilder.createBlock(
137+
&targetRegion, {}, block.getArgumentTypes(),
138+
llvm::map_to_vector(block.getArguments(),
139+
[](BlockArgument arg) { return arg.getLoc(); }));
140+
Operation *terminator = block.getTerminator();
181141

182-
SmallVector<std::variant<SingleRegion, omp::WorkshareLoopWrapperOp>> regions;
142+
SmallVector<std::variant<SingleRegion, Operation *>> regions;
183143

184-
auto it = wsBlock->begin();
185-
auto getSingleRegion = [&]() {
186-
if (&*it == terminator)
187-
return false;
188-
if (auto pop = dyn_cast<omp::WorkshareLoopWrapperOp>(&*it)) {
189-
regions.push_back(pop);
190-
it++;
144+
auto it = block.begin();
145+
auto getOneRegion = [&]() {
146+
if (&*it == terminator)
147+
return false;
148+
if (mustParallelizeOp(&*it)) {
149+
regions.push_back(&*it);
150+
it++;
151+
return true;
152+
}
153+
SingleRegion sr;
154+
sr.begin = it;
155+
while (&*it != terminator && !mustParallelizeOp(&*it))
156+
it++;
157+
sr.end = it;
158+
assert(sr.begin != sr.end);
159+
regions.push_back(sr);
191160
return true;
161+
};
162+
while (getOneRegion())
163+
;
164+
165+
for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
166+
bool isLast = i + 1 == regions.size();
167+
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
168+
omp::SingleOperands singleOperands;
169+
if (isLast)
170+
singleOperands.nowait = rootBuilder.getUnitAttr();
171+
omp::SingleOp singleOp =
172+
rootBuilder.create<omp::SingleOp>(loc, singleOperands);
173+
OpBuilder singleBuilder(singleOp);
174+
singleBuilder.createBlock(&singleOp.getRegion());
175+
moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
176+
} else {
177+
auto op = std::get<Operation *>(opOrSingle);
178+
if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
179+
omp::WsloopOperands wsloopOperands;
180+
if (isLast)
181+
wsloopOperands.nowait = rootBuilder.getUnitAttr();
182+
auto wsloop =
183+
rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
184+
auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
185+
rootBuilder.clone(*wslw, rootMapping));
186+
wsloop.getRegion().takeBody(clonedWslw.getRegion());
187+
clonedWslw->erase();
188+
} else {
189+
assert(mustParallelizeOp(op));
190+
Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
191+
for (auto [region, clonedRegion] :
192+
llvm::zip(op->getRegions(), cloned->getRegions()))
193+
parallelizeRegion(region, clonedRegion, rootMapping, loc);
194+
}
195+
}
192196
}
193-
SingleRegion sr;
194-
sr.begin = it;
195-
while (&*it != terminator && !isa<omp::WorkshareLoopWrapperOp>(&*it))
196-
it++;
197-
sr.end = it;
198-
assert(sr.begin != sr.end);
199-
regions.push_back(sr);
200-
return true;
201-
};
202-
while (getSingleRegion())
203-
;
204-
205-
for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
206-
bool isLast = i + 1 == regions.size();
207-
if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
208-
omp::SingleOperands singleOperands;
209-
if (isLast)
210-
singleOperands.nowait = rootBuilder.getUnitAttr();
211-
omp::SingleOp singleOp =
212-
rootBuilder.create<omp::SingleOp>(loc, singleOperands);
213-
OpBuilder singleBuilder(singleOp);
214-
singleBuilder.createBlock(&singleOp.getRegion());
215-
moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
216-
} else {
217-
omp::WsloopOperands wsloopOperands;
218-
if (isLast)
219-
wsloopOperands.nowait = rootBuilder.getUnitAttr();
220-
auto wsloop =
221-
rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
222-
auto wslw = std::get<omp::WorkshareLoopWrapperOp>(loopOrSingle);
223-
auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
224-
rootBuilder.clone(*wslw, rootMapping));
225-
wsloop.getRegion().takeBody(clonedWslw.getRegion());
226-
clonedWslw->erase();
227-
}
197+
198+
rootBuilder.clone(*block.getTerminator(), rootMapping);
228199
}
200+
}
201+
202+
/// Lowers workshare to a sequence of single-thread regions and parallel loops
203+
///
204+
/// For example:
205+
///
206+
/// omp.workshare {
207+
/// %a = fir.allocmem
208+
/// omp.workshare_loop_wrapper {}
209+
/// fir.call Assign %b %a
210+
/// fir.freemem %a
211+
/// }
212+
///
213+
/// becomes
214+
///
215+
/// omp.single {
216+
/// %a = fir.allocmem
217+
/// fir.store %a %tmp
218+
/// }
219+
/// %a_reloaded = fir.load %tmp
220+
/// omp.workshare_loop_wrapper {}
221+
/// omp.single {
222+
/// fir.call Assign %b %a_reloaded
223+
/// fir.freemem %a_reloaded
224+
/// }
225+
///
226+
/// Note that we allocate temporary memory for values in omp.single's which need
227+
/// to be accessed in all threads in the closest omp.parallel
228+
void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
229+
Location loc = wsOp->getLoc();
230+
IRMapping rootMapping;
231+
232+
OpBuilder rootBuilder(wsOp);
233+
234+
// TODO We need something like an scf;execute here, but that is not registered
235+
// so using fir.if for now but it looks like it does not support multiple
236+
// blocks so it doesnt work for multi block case...
237+
auto ifOp = rootBuilder.create<fir::IfOp>(
238+
loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
239+
ifOp.getThenRegion().front().erase();
240+
241+
parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
242+
243+
Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
244+
assert(isa<omp::TerminatorOp>(terminatorOp));
245+
OpBuilder termBuilder(terminatorOp);
229246

230247
if (!wsOp.getNowait())
231-
rootBuilder.create<omp::BarrierOp>(loc);
248+
termBuilder.create<omp::BarrierOp>(loc);
249+
250+
termBuilder.create<fir::ResultOp>(loc, ValueRange());
232251

252+
terminatorOp->erase();
233253
wsOp->erase();
234254

235255
return;

0 commit comments

Comments
 (0)