Skip to content

Commit a42a2ca

Browse files
authored
Avoid buffer hoisting from parallel loops (#90735)
This change corrects an invalid behavior in pass `--buffer-loop-hoisting`. The pass is in charge of extracting buffer allocations (e.g., `memref.alloca`) from loop regions (e.g., `scf.for`) when possible. This works OK for looks with sequential execution semantics. However, a buffer allocated in the body of a parallel loop may be concurrently accessed by multiple thread to store its local data. Extracting such buffer from the loop causes all threads to wrongly share the same memory region. In the following example, dimension 1 of the input tensor is reversed. Dimension 0 is traversed with a parallel loop. ``` func.func @f(%input: memref<2x3xf32>) -> memref<2x3xf32> { %c0 = index.constant 0 %c1 = index.constant 1 %c2 = index.constant 2 %c3 = index.constant 3 %output = memref.alloc() : memref<2x3xf32> scf.parallel (%index) = (%c0) to (%c2) step (%c1) { // Create subviews for working input and output slices %input_slice = memref.subview %input[%index, 2][1, 3][1, -1] : memref<2x3xf32> to memref<1x3xf32, strided<[3, -1], offset: ?>> %output_slice = memref.subview %output[%index, 0][1, 3][1, 1] : memref<2x3xf32> to memref<1x3xf32, strided<[3, 1], offset: ?>> // Copy the input slice into this temporary buffer. This intermediate // copy is unnecessary, but is used for illustration purposes. %temp = memref.alloc() : memref<1x3xf32> memref.copy %input_slice, %temp : memref<1x3xf32, strided<[3, -1], offset: ?>> to memref<1x3xf32> // Copy temporary buffer into output slice memref.copy %temp, %output_slice : memref<1x3xf32> to memref<1x3xf32, strided<[3, 1], offset: ?>> scf.reduce } return %output : memref<2x3xf32> } ``` The patch submitted here prevents `%temp = memref.alloc() : memref<1x3xf32>` from being hoisted when the containing op is `scf.parallel` or `scf.forall`. A new op trait called `HasParallelRegion` is introduced and assigned to these two ops to indicate that their regions have parallel execution semantics. @joker-eph @ftynse @nicolasvasilache @sabauma
1 parent 1022636 commit a42a2ca

File tree

5 files changed

+82
-7
lines changed

5 files changed

+82
-7
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def ForallOp : SCF_Op<"forall", [
307307
RecursiveMemoryEffects,
308308
SingleBlockImplicitTerminator<"scf::InParallelOp">,
309309
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
310-
DestinationStyleOpInterface
310+
DestinationStyleOpInterface,
311+
HasParallelRegion
311312
]> {
312313
let summary = "evaluate a block multiple times in parallel";
313314
let description = [{
@@ -764,7 +765,8 @@ def ParallelOp : SCF_Op<"parallel",
764765
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
765766
RecursiveMemoryEffects,
766767
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
767-
SingleBlockImplicitTerminator<"scf::ReduceOp">]> {
768+
SingleBlockImplicitTerminator<"scf::ReduceOp">,
769+
HasParallelRegion]> {
768770
let summary = "parallel for operation";
769771
let description = [{
770772
The "scf.parallel" operation represents a loop nest taking 4 groups of SSA

mlir/include/mlir/Interfaces/LoopLikeInterface.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,31 @@ namespace detail {
2929
/// Verify invariants of the LoopLikeOpInterface.
3030
LogicalResult verifyLoopLikeOpInterface(Operation *op);
3131
} // namespace detail
32+
33+
//===----------------------------------------------------------------------===//
34+
// Traits
35+
//===----------------------------------------------------------------------===//
36+
37+
namespace OpTrait {
38+
// A trait indicating that the single region contained in the operation has
39+
// parallel execution semantics. This may have implications in a certain pass.
40+
// For example, buffer hoisting is illegal in parallel loops, and local buffers
41+
// may be accessed by parallel threads simultaneously.
42+
template <typename ConcreteType>
43+
class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
44+
public:
45+
static LogicalResult verifyTrait(Operation *op) {
46+
return impl::verifyOneRegion(op);
47+
}
48+
};
49+
50+
} // namespace OpTrait
3251
} // namespace mlir
3352

53+
//===----------------------------------------------------------------------===//
54+
// Interfaces
55+
//===----------------------------------------------------------------------===//
56+
3457
/// Include the generated interface declarations.
3558
#include "mlir/Interfaces/LoopLikeInterface.h.inc"
3659

mlir/include/mlir/Interfaces/LoopLikeInterface.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
include "mlir/IR/OpBase.td"
1717

18+
//===----------------------------------------------------------------------===//
19+
// Interfaces
20+
//===----------------------------------------------------------------------===//
21+
1822
def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
1923
let description = [{
2024
Contains helper functions to query properties and perform transformations
@@ -371,4 +375,11 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
371375
}];
372376
}
373377

378+
//===----------------------------------------------------------------------===//
379+
// Traits
380+
//===----------------------------------------------------------------------===//
381+
382+
// Op contains a region with parallel execution semantics
383+
def HasParallelRegion : NativeOpTrait<"HasParallelRegion">;
384+
374385
#endif // MLIR_INTERFACES_LOOPLIKEINTERFACE

mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ static bool isLoop(Operation *op) {
5959
return regionInterface.hasLoop();
6060
}
6161

62+
/// Return whether the given operation is a loop with sequential execution
63+
/// semantics.
64+
static bool isSequentialLoop(Operation *op) {
65+
return !op->hasTrait<OpTrait::HasParallelRegion>() && isLoop(op);
66+
}
67+
6268
/// Returns true if the given operation implements the AllocationOpInterface
6369
/// and it supports the dominate block hoisting.
6470
static bool allowAllocDominateBlockHoisting(Operation *op) {
@@ -338,12 +344,13 @@ struct BufferAllocationLoopHoistingState : BufferAllocationHoistingStateBase {
338344
return dependencyBlock ? dependencyBlock : nullptr;
339345
}
340346

341-
/// Returns true if the given operation represents a loop and one of the
342-
/// aliases caused the `aliasDominatorBlock` to be "above" the block of the
343-
/// given loop operation. If this is the case, it indicates that the
344-
/// allocation is passed via a back edge.
347+
/// Returns true if the given operation represents a loop with sequential
348+
/// execution semantics and one of the aliases caused the
349+
/// `aliasDominatorBlock` to be "above" the block of the given loop operation.
350+
/// If this is the case, it indicates that the allocation is passed via a back
351+
/// edge.
345352
bool isLegalPlacement(Operation *op) {
346-
return isLoop(op) &&
353+
return isSequentialLoop(op) &&
347354
!dominators->dominates(aliasDominatorBlock, op->getBlock());
348355
}
349356

mlir/test/Dialect/Bufferization/Transforms/buffer-loop-hoisting.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,38 @@ func.func @partial_hoist_multiple_loop_dependency(
461461

462462
// -----
463463

464+
// CHECK-LABEL: func @no_hoist_parallel
465+
func.func @no_hoist_parallel(
466+
%lb: index,
467+
%ub: index,
468+
%step: index) {
469+
scf.parallel (%i) = (%lb) to (%ub) step (%step) {
470+
%0 = memref.alloc() : memref<2xf32>
471+
scf.reduce
472+
}
473+
return
474+
}
475+
476+
// CHECK: memref.alloc
477+
// CHECK-NEXT: scf.reduce
478+
479+
// -----
480+
481+
func.func @no_hoist_forall(
482+
%lb: index,
483+
%ub: index,
484+
%step: index) {
485+
scf.forall (%i) = (%lb) to (%ub) step (%step) {
486+
%1 = memref.alloc() : memref<2xf32>
487+
}
488+
return
489+
}
490+
491+
// CHECK: scf.forall
492+
// CHECK-NEXT: memref.alloc
493+
494+
// -----
495+
464496
// Test with allocas to ensure that op is also considered.
465497

466498
// CHECK-LABEL: func @hoist_alloca

0 commit comments

Comments
 (0)