|
13 | 13 | #include <flang/Optimizer/Dialect/FIRType.h> |
14 | 14 | #include <flang/Optimizer/HLFIR/HLFIROps.h> |
15 | 15 | #include <flang/Optimizer/OpenMP/Passes.h> |
| 16 | +#include <llvm/ADT/BreadthFirstIterator.h> |
16 | 17 | #include <llvm/ADT/STLExtras.h> |
17 | 18 | #include <llvm/ADT/SmallVectorExtras.h> |
18 | 19 | #include <llvm/ADT/iterator_range.h> |
19 | 20 | #include <llvm/Support/ErrorHandling.h> |
20 | 21 | #include <mlir/Dialect/Arith/IR/Arith.h> |
21 | 22 | #include <mlir/Dialect/LLVMIR/LLVMTypes.h> |
| 23 | +#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h> |
22 | 24 | #include <mlir/Dialect/OpenMP/OpenMPDialect.h> |
23 | 25 | #include <mlir/Dialect/SCF/IR/SCF.h> |
24 | 26 | #include <mlir/IR/BuiltinOps.h> |
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) { |
161 | 163 | } |
162 | 164 |
|
163 | 165 | static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, |
164 | | - IRMapping &rootMapping, Location loc) { |
| 166 | + IRMapping &rootMapping, Location loc, |
| 167 | + mlir::DominanceInfo &di) { |
165 | 168 | OpBuilder rootBuilder(sourceRegion.getContext()); |
166 | 169 | ModuleOp m = sourceRegion.getParentOfType<ModuleOp>(); |
167 | 170 | OpBuilder copyFuncBuilder(m.getBodyRegion()); |
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, |
214 | 217 | return copyPrivate; |
215 | 218 | }; |
216 | 219 |
|
217 | | - // TODO Need to handle these (clone them) in dominator tree order |
218 | 220 | for (Block &block : sourceRegion) { |
219 | | - rootBuilder.createBlock( |
| 221 | + Block *targetBlock = rootBuilder.createBlock( |
220 | 222 | &targetRegion, {}, block.getArgumentTypes(), |
221 | 223 | llvm::map_to_vector(block.getArguments(), |
222 | 224 | [](BlockArgument arg) { return arg.getLoc(); })); |
223 | | - Operation *terminator = block.getTerminator(); |
| 225 | + rootMapping.map(&block, targetBlock); |
| 226 | + rootMapping.map(block.getArguments(), targetBlock->getArguments()); |
| 227 | + } |
224 | 228 |
|
| 229 | + auto handleOneBlock = [&](Block &block) { |
| 230 | + Block &targetBlock = *rootMapping.lookup(&block); |
| 231 | + rootBuilder.setInsertionPointToStart(&targetBlock); |
| 232 | + Operation *terminator = block.getTerminator(); |
225 | 233 | SmallVector<std::variant<SingleRegion, Operation *>> regions; |
226 | 234 |
|
227 | 235 | auto it = block.begin(); |
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, |
298 | 306 | Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping); |
299 | 307 | for (auto [region, clonedRegion] : |
300 | 308 | llvm::zip(op->getRegions(), cloned->getRegions())) |
301 | | - parallelizeRegion(region, clonedRegion, rootMapping, loc); |
| 309 | + parallelizeRegion(region, clonedRegion, rootMapping, loc, di); |
302 | 310 | } |
303 | 311 | } |
304 | 312 | } |
305 | 313 |
|
306 | 314 | rootBuilder.clone(*block.getTerminator(), rootMapping); |
| 315 | + }; |
| 316 | + |
| 317 | + if (sourceRegion.hasOneBlock()) { |
| 318 | + handleOneBlock(sourceRegion.front()); |
| 319 | + } else { |
| 320 | + auto &domTree = di.getDomTree(&sourceRegion); |
| 321 | + for (auto node : llvm::breadth_first(domTree.getRootNode())) { |
| 322 | + handleOneBlock(*node->getBlock()); |
| 323 | + } |
307 | 324 | } |
308 | 325 |
|
309 | 326 | for (Block &targetBlock : targetRegion) |
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion, |
336 | 353 | /// |
337 | 354 | /// Note that we allocate temporary memory for values in omp.single's which need |
338 | 355 | /// to be accessed in all threads in the closest omp.parallel |
339 | | -void lowerWorkshare(mlir::omp::WorkshareOp wsOp) { |
| 356 | +LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) { |
340 | 357 | Location loc = wsOp->getLoc(); |
341 | 358 | IRMapping rootMapping; |
342 | 359 |
|
343 | 360 | OpBuilder rootBuilder(wsOp); |
344 | 361 |
|
345 | | - // TODO We need something like an scf;execute here, but that is not registered |
346 | | - // so using fir.if for now but it looks like it does not support multiple |
347 | | - // blocks so it doesnt work for multi block case... |
348 | | - auto ifOp = rootBuilder.create<fir::IfOp>( |
349 | | - loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false); |
350 | | - ifOp.getThenRegion().front().erase(); |
351 | | - |
352 | | - parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc); |
353 | | - |
354 | | - Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator(); |
355 | | - assert(isa<omp::TerminatorOp>(terminatorOp)); |
356 | | - OpBuilder termBuilder(terminatorOp); |
357 | | - |
| 362 | + // TODO We need something like an scf.execute here, but that is not registered |
| 363 | + // so using omp.workshare as a placeholder. We need this op as our |
| 364 | + // parallelizeRegion works on regions and not blocks. |
| 365 | + omp::WorkshareOp newOp = |
| 366 | + rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands()); |
358 | 367 | if (!wsOp.getNowait()) |
359 | | - termBuilder.create<omp::BarrierOp>(loc); |
360 | | - |
361 | | - termBuilder.create<fir::ResultOp>(loc, ValueRange()); |
362 | | - |
363 | | - terminatorOp->erase(); |
| 368 | + rootBuilder.create<omp::BarrierOp>(loc); |
| 369 | + |
| 370 | + parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di); |
| 371 | + |
| 372 | + if (wsOp.getRegion().getBlocks().size() != 1) |
| 373 | + return failure(); |
| 374 | + |
| 375 | + // Inline the contents of the placeholder workshare op into its parent block. |
| 376 | + Block *theBlock = &newOp.getRegion().front(); |
| 377 | + Operation *term = theBlock->getTerminator(); |
| 378 | + Block *parentBlock = wsOp->getBlock(); |
| 379 | + parentBlock->getOperations().splice(newOp->getIterator(), |
| 380 | + theBlock->getOperations()); |
| 381 | + assert(term->getNumOperands() == 0); |
| 382 | + term->erase(); |
| 383 | + newOp->erase(); |
364 | 384 | wsOp->erase(); |
365 | | - |
366 | | - return; |
| 385 | + return success(); |
367 | 386 | } |
368 | 387 |
|
369 | 388 | class LowerWorksharePass |
370 | 389 | : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> { |
371 | 390 | public: |
372 | 391 | void runOnOperation() override { |
373 | | - SmallPtrSet<Operation *, 8> parents; |
| 392 | + mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>(); |
374 | 393 | getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) { |
375 | | - Operation *isolatedParent = |
376 | | - wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); |
377 | | - parents.insert(isolatedParent); |
378 | | - |
379 | | - lowerWorkshare(wsOp); |
| 394 | + if (failed(lowerWorkshare(wsOp, di))) |
| 395 | + signalPassFailure(); |
380 | 396 | }); |
381 | 397 | } |
382 | 398 | }; |
|
0 commit comments