8
8
// Lower omp workshare construct.
9
9
// ===----------------------------------------------------------------------===//
10
10
11
- #include " flang/Optimizer/Dialect/FIROps.h"
12
- #include " flang/Optimizer/Dialect/FIRType.h"
13
- #include " flang/Optimizer/OpenMP/Passes.h"
14
- #include " mlir/Dialect/OpenMP/OpenMPDialect.h"
15
- #include " mlir/IR/BuiltinOps.h"
16
- #include " mlir/IR/IRMapping.h"
17
- #include " mlir/IR/OpDefinition.h"
18
- #include " mlir/IR/PatternMatch.h"
19
- #include " mlir/Support/LLVM.h"
20
- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
21
- #include " llvm/ADT/STLExtras.h"
22
- #include " llvm/ADT/SmallVectorExtras.h"
23
- #include " llvm/ADT/iterator_range.h"
24
-
11
+ #include < flang/Optimizer/Builder/FIRBuilder.h>
12
+ #include < flang/Optimizer/Dialect/FIROps.h>
13
+ #include < flang/Optimizer/Dialect/FIRType.h>
14
+ #include < flang/Optimizer/HLFIR/HLFIROps.h>
15
+ #include < flang/Optimizer/OpenMP/Passes.h>
16
+ #include < llvm/ADT/STLExtras.h>
17
+ #include < llvm/ADT/SmallVectorExtras.h>
18
+ #include < llvm/ADT/iterator_range.h>
19
+ #include < llvm/Support/ErrorHandling.h>
25
20
#include < mlir/Dialect/Arith/IR/Arith.h>
26
- #include < mlir/Dialect/OpenMP/OpenMPClauseOperands .h>
21
+ #include < mlir/Dialect/OpenMP/OpenMPDialect .h>
27
22
#include < mlir/Dialect/SCF/IR/SCF.h>
23
+ #include < mlir/IR/BuiltinOps.h>
24
+ #include < mlir/IR/IRMapping.h>
25
+ #include < mlir/IR/OpDefinition.h>
26
+ #include < mlir/IR/PatternMatch.h>
28
27
#include < mlir/IR/Visitors.h>
29
28
#include < mlir/Interfaces/SideEffectInterfaces.h>
29
+ #include < mlir/Support/LLVM.h>
30
+ #include < mlir/Transforms/GreedyPatternRewriteDriver.h>
31
+
30
32
#include < variant>
31
33
32
34
namespace flangomp {
@@ -71,34 +73,66 @@ static bool isSupportedByFirAlloca(Type ty) {
71
73
}
72
74
73
75
static bool mustParallelizeOp (Operation *op) {
76
+ // TODO as in shouldUseWorkshareLowering we be careful not to pick up
77
+ // workshare_loop_wrapper in nested omp.parallel ops
74
78
return op
75
79
->walk (
76
80
[](omp::WorkshareLoopWrapperOp) { return WalkResult::interrupt (); })
77
81
.wasInterrupted ();
78
82
}
79
83
80
84
static bool isSafeToParallelize (Operation *op) {
81
- return isa<fir::DeclareOp>(op) || isPure (op);
85
+ return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
86
+ isMemoryEffectFree (op);
87
+ }
88
+
89
+ static mlir::func::FuncOp createCopyFunc (mlir::Location loc, mlir::Type varType,
90
+ fir::FirOpBuilder builder) {
91
+ mlir::ModuleOp module = builder.getModule ();
92
+ mlir::Type eleTy = mlir::cast<fir::ReferenceType>(varType).getEleTy ();
93
+
94
+ std::string copyFuncName =
95
+ fir::getTypeAsString (eleTy, builder.getKindMap (), " _workshare_copy" );
96
+
97
+ if (auto decl = module .lookupSymbol <mlir::func::FuncOp>(copyFuncName))
98
+ return decl;
99
+ // create function
100
+ mlir::OpBuilder::InsertionGuard guard (builder);
101
+ mlir::OpBuilder modBuilder (module .getBodyRegion ());
102
+ llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
103
+ auto funcType = mlir::FunctionType::get (builder.getContext (), argsTy, {});
104
+ mlir::func::FuncOp funcOp =
105
+ modBuilder.create <mlir::func::FuncOp>(loc, copyFuncName, funcType);
106
+ funcOp.setVisibility (mlir::SymbolTable::Visibility::Private);
107
+ builder.createBlock (&funcOp.getRegion (), funcOp.getRegion ().end (), argsTy,
108
+ {loc, loc});
109
+ builder.setInsertionPointToStart (&funcOp.getRegion ().back ());
110
+ builder.create <mlir::func::ReturnOp>(loc);
111
+ return funcOp;
82
112
}
83
113
84
114
static void parallelizeRegion (Region &sourceRegion, Region &targetRegion,
85
115
IRMapping &rootMapping, Location loc) {
86
116
Operation *parentOp = sourceRegion.getParentOp ();
87
117
OpBuilder rootBuilder (sourceRegion.getContext ());
88
118
119
+ ModuleOp m = sourceRegion.getParentOfType <ModuleOp>();
120
+ OpBuilder copyFuncBuilder (m.getBodyRegion ());
121
+ fir::FirOpBuilder firCopyFuncBuilder (copyFuncBuilder, m);
122
+
89
123
// TODO need to copyprivate the alloca's
90
- auto mapReloadedValue = [&](Value v, OpBuilder singleBuilder,
91
- IRMapping singleMapping) {
92
- OpBuilder allocaBuilder (&targetRegion. front (). front ());
124
+ auto mapReloadedValue =
125
+ [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
126
+ OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
93
127
if (auto reloaded = rootMapping.lookupOrNull (v))
94
- return ;
128
+ return nullptr ;
95
129
Type llvmPtrTy = LLVM::LLVMPointerType::get (allocaBuilder.getContext ());
96
130
Type ty = v.getType ();
97
131
Value alloc, reloaded;
98
132
if (isSupportedByFirAlloca (ty)) {
99
133
alloc = allocaBuilder.create <fir::AllocaOp>(loc, ty);
100
134
singleBuilder.create <fir::StoreOp>(loc, singleMapping.lookup (v), alloc);
101
- reloaded = rootBuilder .create <fir::LoadOp>(loc, ty, alloc);
135
+ reloaded = parallelBuilder .create <fir::LoadOp>(loc, ty, alloc);
102
136
} else {
103
137
auto one = allocaBuilder.create <LLVM::ConstantOp>(
104
138
loc, allocaBuilder.getI32Type (), 1 );
@@ -109,21 +143,25 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
109
143
loc, llvmPtrTy, singleMapping.lookup (v))
110
144
.getResult (0 );
111
145
singleBuilder.create <LLVM::StoreOp>(loc, toStore, alloc);
112
- reloaded = rootBuilder .create <LLVM::LoadOp>(loc, llvmPtrTy, alloc);
146
+ reloaded = parallelBuilder .create <LLVM::LoadOp>(loc, llvmPtrTy, alloc);
113
147
reloaded =
114
- rootBuilder .create <UnrealizedConversionCastOp>(loc, ty, reloaded)
148
+ parallelBuilder .create <UnrealizedConversionCastOp>(loc, ty, reloaded)
115
149
.getResult (0 );
116
150
}
117
151
rootMapping.map (v, reloaded);
152
+ return alloc;
118
153
};
119
154
120
- auto moveToSingle = [&](SingleRegion sr, OpBuilder singleBuilder) {
155
+ auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
156
+ OpBuilder singleBuilder,
157
+ OpBuilder parallelBuilder) -> SmallVector<Value> {
121
158
IRMapping singleMapping = rootMapping;
159
+ SmallVector<Value> copyPrivate;
122
160
123
161
for (Operation &op : llvm::make_range (sr.begin , sr.end )) {
124
162
singleBuilder.clone (op, singleMapping);
125
163
if (isSafeToParallelize (&op)) {
126
- rootBuilder .clone (op, rootMapping);
164
+ parallelBuilder .clone (op, rootMapping);
127
165
} else {
128
166
// Prepare reloaded values for results of operations that cannot be
129
167
// safely parallelized and which are used after the region `sr`
@@ -132,16 +170,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
132
170
Operation *user = use.getOwner ();
133
171
while (user->getParentOp () != parentOp)
134
172
user = user->getParentOp ();
135
- if (!(user->isBeforeInBlock (&*sr.end ) &&
136
- sr.begin ->isBeforeInBlock (user))) {
137
- // We need to reload
138
- mapReloadedValue (use.get (), singleBuilder, singleMapping);
173
+ // TODO we need to look at transitively used vals
174
+ if (true || !(user->isBeforeInBlock (&*sr.end ) &&
175
+ sr.begin ->isBeforeInBlock (user))) {
176
+ auto alloc =
177
+ mapReloadedValue (use.get (), allocaBuilder, singleBuilder,
178
+ parallelBuilder, singleMapping);
179
+ if (alloc)
180
+ copyPrivate.push_back (alloc);
139
181
}
140
182
}
141
183
}
142
184
}
143
185
}
144
186
singleBuilder.create <omp::TerminatorOp>(loc);
187
+ return copyPrivate;
145
188
};
146
189
147
190
// TODO Need to handle these (clone them) in dominator tree order
@@ -178,14 +221,45 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
178
221
for (auto [i, opOrSingle] : llvm::enumerate (regions)) {
179
222
bool isLast = i + 1 == regions.size ();
180
223
if (std::holds_alternative<SingleRegion>(opOrSingle)) {
224
+ OpBuilder singleBuilder (sourceRegion.getContext ());
225
+ Block *singleBlock = new Block ();
226
+ singleBuilder.setInsertionPointToStart (singleBlock);
227
+
228
+ OpBuilder allocaBuilder (sourceRegion.getContext ());
229
+ Block *allocaBlock = new Block ();
230
+ allocaBuilder.setInsertionPointToStart (allocaBlock);
231
+
232
+ OpBuilder parallelBuilder (sourceRegion.getContext ());
233
+ Block *parallelBlock = new Block ();
234
+ parallelBuilder.setInsertionPointToStart (parallelBlock);
235
+
181
236
omp::SingleOperands singleOperands;
182
237
if (isLast)
183
238
singleOperands.nowait = rootBuilder.getUnitAttr ();
239
+ auto insPtAtSingle = rootBuilder.saveInsertionPoint ();
240
+ singleOperands.copyprivateVars =
241
+ moveToSingle (std::get<SingleRegion>(opOrSingle), allocaBuilder,
242
+ singleBuilder, parallelBuilder);
243
+ for (auto var : singleOperands.copyprivateVars ) {
244
+ Type ty;
245
+ if (auto firAlloca = var.getDefiningOp <fir::AllocaOp>()) {
246
+ ty = firAlloca.getAllocatedType ();
247
+ } else {
248
+ llvm_unreachable (" unexpected" );
249
+ }
250
+ mlir::func::FuncOp funcOp =
251
+ createCopyFunc (loc, var.getType (), firCopyFuncBuilder);
252
+ singleOperands.copyprivateSyms .push_back (SymbolRefAttr::get (funcOp));
253
+ }
184
254
omp::SingleOp singleOp =
185
255
rootBuilder.create <omp::SingleOp>(loc, singleOperands);
186
- OpBuilder singleBuilder (singleOp);
187
- singleBuilder.createBlock (&singleOp.getRegion ());
188
- moveToSingle (std::get<SingleRegion>(opOrSingle), singleBuilder);
256
+ singleOp.getRegion ().push_back (singleBlock);
257
+ rootBuilder.getInsertionBlock ()->getOperations ().splice (
258
+ rootBuilder.getInsertionPoint (), parallelBlock->getOperations ());
259
+ targetRegion.front ().getOperations ().splice (
260
+ singleOp->getIterator (), allocaBlock->getOperations ());
261
+ delete allocaBlock;
262
+ delete parallelBlock;
189
263
} else {
190
264
auto op = std::get<Operation *>(opOrSingle);
191
265
if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
0 commit comments