Skip to content

Commit c2ae97a

Browse files
committed
Use copyprivate to scatter val from omp.single
TODO: still need to implement the copy function. TODO: the transitive check for usage outside of omp.single is not implemented yet.
1 parent 0513b04 commit c2ae97a

File tree

4 files changed

+111
-34
lines changed

4 files changed

+111
-34
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#define FORTRAN_OPTIMIZER_OPENMP_PASSES_H
1515

1616
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/IR/BuiltinOps.h"
1718
#include "mlir/Pass/Pass.h"
1819
#include "mlir/Pass/PassRegistry.h"
20+
1921
#include <memory>
2022

2123
namespace flangomp {

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def LowerWorkshare : Pass<"lower-workshare"> {
14+
// Needs to be scheduled on Module as we create functions in it
15+
def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
1516
let summary = "Lower workshare construct";
1617
}
1718

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
345345
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
346346
pm.addPass(hlfir::createBufferizeHLFIR());
347347
pm.addPass(hlfir::createConvertHLFIRtoFIR());
348-
addNestedPassToAllTopLevelOperations(pm, flangomp::createLowerWorkshare);
348+
pm.addPass(flangomp::createLowerWorkshare());
349349
}
350350

351351
/// Create a pass pipeline for handling certain OpenMP transformations needed

flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp

Lines changed: 106 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,27 @@
88
// Lower omp workshare construct.
99
//===----------------------------------------------------------------------===//
1010

11-
#include "flang/Optimizer/Dialect/FIROps.h"
12-
#include "flang/Optimizer/Dialect/FIRType.h"
13-
#include "flang/Optimizer/OpenMP/Passes.h"
14-
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15-
#include "mlir/IR/BuiltinOps.h"
16-
#include "mlir/IR/IRMapping.h"
17-
#include "mlir/IR/OpDefinition.h"
18-
#include "mlir/IR/PatternMatch.h"
19-
#include "mlir/Support/LLVM.h"
20-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21-
#include "llvm/ADT/STLExtras.h"
22-
#include "llvm/ADT/SmallVectorExtras.h"
23-
#include "llvm/ADT/iterator_range.h"
24-
11+
#include <flang/Optimizer/Builder/FIRBuilder.h>
12+
#include <flang/Optimizer/Dialect/FIROps.h>
13+
#include <flang/Optimizer/Dialect/FIRType.h>
14+
#include <flang/Optimizer/HLFIR/HLFIROps.h>
15+
#include <flang/Optimizer/OpenMP/Passes.h>
16+
#include <llvm/ADT/STLExtras.h>
17+
#include <llvm/ADT/SmallVectorExtras.h>
18+
#include <llvm/ADT/iterator_range.h>
19+
#include <llvm/Support/ErrorHandling.h>
2520
#include <mlir/Dialect/Arith/IR/Arith.h>
26-
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
21+
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
2722
#include <mlir/Dialect/SCF/IR/SCF.h>
23+
#include <mlir/IR/BuiltinOps.h>
24+
#include <mlir/IR/IRMapping.h>
25+
#include <mlir/IR/OpDefinition.h>
26+
#include <mlir/IR/PatternMatch.h>
2827
#include <mlir/IR/Visitors.h>
2928
#include <mlir/Interfaces/SideEffectInterfaces.h>
29+
#include <mlir/Support/LLVM.h>
30+
#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
31+
3032
#include <variant>
3133

3234
namespace flangomp {
@@ -71,34 +73,66 @@ static bool isSupportedByFirAlloca(Type ty) {
7173
}
7274

7375
static bool mustParallelizeOp(Operation *op) {
76+
// TODO as in shouldUseWorkshareLowering, we must be careful not to pick up
77+
// workshare_loop_wrapper in nested omp.parallel ops
7478
return op
7579
->walk(
7680
[](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt(); })
7781
.wasInterrupted();
7882
}
7983

8084
static bool isSafeToParallelize(Operation *op) {
81-
return isa<fir::DeclareOp>(op) || isPure(op);
85+
return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
86+
isMemoryEffectFree(op);
87+
}
88+
89+
static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
90+
fir::FirOpBuilder builder) {
91+
mlir::ModuleOp module = builder.getModule();
92+
mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy();
93+
94+
std::string copyFuncName =
95+
fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
96+
97+
if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
98+
return decl;
99+
// create function
100+
mlir::OpBuilder::InsertionGuard guard(builder);
101+
mlir::OpBuilder modBuilder(module.getBodyRegion());
102+
llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
103+
auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
104+
mlir::func::FuncOp funcOp =
105+
modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
106+
funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
107+
builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
108+
{loc, loc});
109+
builder.setInsertionPointToStart(&funcOp.getRegion().back());
110+
builder.create<mlir::func::ReturnOp>(loc);
111+
return funcOp;
82112
}
83113

84114
static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
85115
IRMapping &rootMapping, Location loc) {
86116
Operation *parentOp = sourceRegion.getParentOp();
87117
OpBuilder rootBuilder(sourceRegion.getContext());
88118

119+
ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
120+
OpBuilder copyFuncBuilder(m.getBodyRegion());
121+
fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
122+
89123
// TODO need to copyprivate the alloca's
90-
auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
91-
IRMapping singleMapping) {
92-
OpBuilder allocaBuilder(&targetRegion.front().front());
124+
auto mapReloadedValue =
125+
[&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
126+
OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
93127
if (auto reloaded = rootMapping.lookupOrNull(v))
94-
return;
128+
return nullptr;
95129
Type llvmPtrTy = LLVM::LLVMPointerType::get(allocaBuilder.getContext());
96130
Type ty = v.getType();
97131
Value alloc, reloaded;
98132
if (isSupportedByFirAlloca(ty)) {
99133
alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
100134
singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
101-
reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
135+
reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
102136
} else {
103137
auto one = allocaBuilder.create<LLVM::ConstantOp>(
104138
loc, allocaBuilder.getI32Type(), 1);
@@ -109,21 +143,25 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
109143
loc, llvmPtrTy, singleMapping.lookup(v))
110144
.getResult(0);
111145
singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
112-
reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
146+
reloaded = parallelBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
113147
reloaded =
114-
rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
148+
parallelBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
115149
.getResult(0);
116150
}
117151
rootMapping.map(v, reloaded);
152+
return alloc;
118153
};
119154

120-
auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
155+
auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
156+
OpBuilder singleBuilder,
157+
OpBuilder parallelBuilder) -> SmallVector<Value> {
121158
IRMapping singleMapping = rootMapping;
159+
SmallVector<Value> copyPrivate;
122160

123161
for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
124162
singleBuilder.clone(op, singleMapping);
125163
if (isSafeToParallelize(&op)) {
126-
rootBuilder.clone(op, rootMapping);
164+
parallelBuilder.clone(op, rootMapping);
127165
} else {
128166
// Prepare reloaded values for results of operations that cannot be
129167
// safely parallelized and which are used after the region `sr`
@@ -132,16 +170,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
132170
Operation *user = use.getOwner();
133171
while (user->getParentOp() != parentOp)
134172
user = user->getParentOp();
135-
if (!(user->isBeforeInBlock(&*sr.end) &&
136-
sr.begin->isBeforeInBlock(user))) {
137-
// We need to reload
138-
mapReloadedValue(use.get(), singleBuilder, singleMapping);
173+
// TODO we need to look at transitively used vals
174+
if (true || !(user->isBeforeInBlock(&*sr.end) &&
175+
sr.begin->isBeforeInBlock(user))) {
176+
auto alloc =
177+
mapReloadedValue(use.get(), allocaBuilder, singleBuilder,
178+
parallelBuilder, singleMapping);
179+
if (alloc)
180+
copyPrivate.push_back(alloc);
139181
}
140182
}
141183
}
142184
}
143185
}
144186
singleBuilder.create<omp::TerminatorOp>(loc);
187+
return copyPrivate;
145188
};
146189

147190
// TODO Need to handle these (clone them) in dominator tree order
@@ -178,14 +221,45 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
178221
for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
179222
bool isLast = i + 1 == regions.size();
180223
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
224+
OpBuilder singleBuilder(sourceRegion.getContext());
225+
Block *singleBlock = new Block();
226+
singleBuilder.setInsertionPointToStart(singleBlock);
227+
228+
OpBuilder allocaBuilder(sourceRegion.getContext());
229+
Block *allocaBlock = new Block();
230+
allocaBuilder.setInsertionPointToStart(allocaBlock);
231+
232+
OpBuilder parallelBuilder(sourceRegion.getContext());
233+
Block *parallelBlock = new Block();
234+
parallelBuilder.setInsertionPointToStart(parallelBlock);
235+
181236
omp::SingleOperands singleOperands;
182237
if (isLast)
183238
singleOperands.nowait = rootBuilder.getUnitAttr();
239+
auto insPtAtSingle = rootBuilder.saveInsertionPoint();
240+
singleOperands.copyprivateVars =
241+
moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
242+
singleBuilder, parallelBuilder);
243+
for (auto var : singleOperands.copyprivateVars) {
244+
Type ty;
245+
if (auto firAlloca = var.getDefiningOp<fir::AllocaOp>()) {
246+
ty = firAlloca.getAllocatedType();
247+
} else {
248+
llvm_unreachable("unexpected");
249+
}
250+
mlir::func::FuncOp funcOp =
251+
createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
252+
singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
253+
}
184254
omp::SingleOp singleOp =
185255
rootBuilder.create<omp::SingleOp>(loc, singleOperands);
186-
OpBuilder singleBuilder(singleOp);
187-
singleBuilder.createBlock(&singleOp.getRegion());
188-
moveToSingle(std::get<SingleRegion>(opOrSingle), singleBuilder);
256+
singleOp.getRegion().push_back(singleBlock);
257+
rootBuilder.getInsertionBlock()->getOperations().splice(
258+
rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
259+
targetRegion.front().getOperations().splice(
260+
singleOp->getIterator(), allocaBlock->getOperations());
261+
delete allocaBlock;
262+
delete parallelBlock;
189263
} else {
190264
auto op = std::get<Operation *>(opOrSingle);
191265
if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {

0 commit comments

Comments
 (0)