Skip to content

Commit 27735ad

Browse files
[mlir][Codegen] Remove workaround for handling consumer fusion along multiple operands.
With llvm/llvm-project#145193 it is possible to tile and fuse consumers when the consumer uses multiple results of the tiled loop (as long as the slices of the uses/operands are consistent w.r.t. their use in the consumer). This removes the need for the workaround that was added to handle such cases and generalizes the cases of consumer fusion that can be handled. Fixes iree-org#21087 Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 0ec7c6d commit 27735ad

File tree

4 files changed

+157
-141
lines changed

4 files changed

+157
-141
lines changed

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 56 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Codegen/Common/TileAndFuseUtils.h"
88
#include "llvm/Support/Debug.h"
9+
#include "mlir/Analysis/TopologicalSortUtils.h"
910
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1011

1112
#define DEBUG_TYPE "iree-codegen-common-tile-and-fuse-utils"
@@ -46,36 +47,6 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
4647
}
4748
}
4849

49-
bool warForConsumerFusionSSAViolation(
50-
Operation *rootOp,
51-
const llvm::SmallDenseSet<Operation *> &tiledAndFusedOps) {
52-
auto linalgRootOp = dyn_cast<linalg::LinalgOp>(rootOp);
53-
if (!linalgRootOp) {
54-
return false;
55-
}
56-
SmallVector<utils::IteratorType> iteratorTypes =
57-
linalgRootOp.getIteratorTypesArray();
58-
for (AffineMap map :
59-
llvm::map_range(linalgRootOp.getIndexingMaps(), [](Attribute attr) {
60-
return cast<AffineMapAttr>(attr).getValue();
61-
})) {
62-
if (!compressUnusedDims(map).isIdentity()) {
63-
return false;
64-
}
65-
}
66-
67-
for (OpOperand &use : linalgRootOp->getUses()) {
68-
auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
69-
if (!linalgUser) {
70-
return false;
71-
}
72-
if (!linalgUser.getMatchingIndexingMap(&use).isIdentity()) {
73-
return false;
74-
}
75-
}
76-
return true;
77-
}
78-
7950
void collectTiledAndFusedOps(Operation *rootOp,
8051
llvm::SmallDenseSet<Operation *> &result) {
8152
SmallVector<Operation *> worklist;
@@ -108,56 +79,71 @@ void collectTiledAndFusedOps(Operation *rootOp,
10879
FailureOr<std::queue<Operation *>>
10980
fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
11081
MutableArrayRef<LoopLikeOpInterface> loops,
111-
bool useWARForConsumerFusionSSAViolation) {
112-
auto addCandidateSlices =
113-
[](Operation *fusedOp,
114-
std::queue<tensor::ParallelInsertSliceOp> &candidates) {
115-
for (auto *userOp : fusedOp->getResults().getUsers()) {
116-
if (auto sliceOp =
117-
llvm::dyn_cast<tensor::ParallelInsertSliceOp>(userOp)) {
118-
candidates.push(sliceOp);
119-
}
120-
}
121-
};
122-
82+
std::function<bool(Operation *)> filterFn) {
12383
// Collect the candidate slices which can be potential consumers that can be
12484
// fused.
125-
std::queue<tensor::ParallelInsertSliceOp> candidates;
126-
addCandidateSlices(tiledOp, candidates);
85+
std::queue<SmallVector<Operation *>> candidates;
86+
DenseSet<tensor::ParallelInsertSliceOp> allCandidates;
87+
auto addCandidateSlices = [&candidates, &allCandidates,
88+
&filterFn](Operation *fusedOp) {
89+
for (auto *userOp : fusedOp->getResults().getUsers()) {
90+
auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
91+
if (!sliceOp || allCandidates.contains(sliceOp)) {
92+
continue;
93+
}
94+
95+
auto currLoop =
96+
cast<scf::ForallOp>(sliceOp->getParentOp()->getParentOp());
97+
OpResult loopResult = currLoop.getTiedOpResult(
98+
currLoop.getTiedOpOperand(cast<BlockArgument>(sliceOp.getDest())));
99+
auto users = llvm::to_vector(
100+
llvm::make_filter_range(loopResult.getUsers(), filterFn));
101+
if (users.empty()) {
102+
continue;
103+
}
104+
mlir::computeTopologicalSorting(users);
105+
106+
Operation *fusableUser = users.front();
107+
// Check all operands from the `scf.forall`
108+
SmallVector<OpResult> loopResults;
109+
for (OpOperand &opOperand : fusableUser->getOpOperands()) {
110+
if (opOperand.get().getDefiningOp() == currLoop.getOperation()) {
111+
loopResults.push_back(cast<OpResult>(opOperand.get()));
112+
}
113+
}
114+
115+
SmallVector<Operation *> fusedSlices;
116+
for (auto result : loopResults) {
117+
BlockArgument tiedBlockArg =
118+
currLoop.getTiedBlockArgument(currLoop.getTiedOpOperand(result));
119+
SmallVector<tensor::ParallelInsertSliceOp> slices = llvm::map_to_vector(
120+
currLoop.getCombiningOps(tiedBlockArg), [](Operation *op) {
121+
return cast<tensor::ParallelInsertSliceOp>(op);
122+
});
123+
llvm::append_range(fusedSlices, slices);
124+
allCandidates.insert_range(slices);
125+
}
126+
if (!fusedSlices.empty()) {
127+
candidates.emplace(std::move(fusedSlices));
128+
}
129+
}
130+
};
131+
132+
addCandidateSlices(tiledOp);
127133

128134
std::queue<Operation *> newFusionOpportunities;
129135
while (!candidates.empty()) {
130136

131137
// Traverse the slices in BFS fashion.
132-
tensor::ParallelInsertSliceOp candidateSliceOp = candidates.front();
138+
SmallVector<Operation *> candidateSlices = candidates.front();
133139
candidates.pop();
134140

135141
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
136-
mlir::scf::tileAndFuseConsumerOfSlices(
137-
rewriter, candidateSliceOp.getOperation(), loops);
142+
mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
143+
loops);
138144
if (failed(fusedResult)) {
139-
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
140-
<< candidateSliceOp << "\n");
141-
continue;
142-
}
143-
144-
// Implement the WAR for consumer fusion SSA violation (as described below
145-
// in the comments for `warForConsumerFusionSSAViolation`)
146-
if (useWARForConsumerFusionSSAViolation) {
147-
for (auto [tiledOpResult, loopResult] :
148-
llvm::zip(tiledOp->getResults(), loops.back()->getResults())) {
149-
for (OpOperand &use : loopResult.getUses()) {
150-
Operation *user = use.getOwner();
151-
if (user->getParentOp() != loops.back()) {
152-
continue;
153-
}
154-
auto slice = dyn_cast<tensor::ExtractSliceOp>(user);
155-
if (!slice) {
156-
return failure();
157-
}
158-
rewriter.replaceAllOpUsesWith(slice, tiledOpResult);
159-
}
160-
}
145+
return candidateSlices.front()->emitOpError(
146+
"failed to fuse consumer of slice");
161147
}
162148

163149
// Replace the original consumer operation with the tiled implementation.
@@ -168,8 +154,7 @@ fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
168154
// values produced by operations that implement the `TilingInterface`.
169155
// Add these operations to the worklist.
170156
addCandidateSlices(
171-
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner(),
172-
candidates);
157+
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner());
173158

174159
// Add the list of new producer fusion opportunities.
175160
for (auto tiledOp : fusedResult.value().tiledOps) {

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h

Lines changed: 6 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -24,66 +24,6 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
2424
scf::SCFTileAndFuseOptions &options,
2525
MutableArrayRef<LoopLikeOpInterface> loops);
2626

27-
/// Consider the following case
28-
///
29-
/// ```mlir
30-
/// %0:2 = linalg.generic {
31-
/// indexing_maps = [....,
32-
/// affine_map<(d0, d1, d2) -> (d0, d1),
33-
/// affine_map<(d0, d1, d2) -> (d0, d1)>]}
34-
/// %1 = linalg.generic ins(%0#0, %0#1) {
35-
/// indexing_maps = [affine_map<(d0, d1) -> (d0, d1),
36-
/// affine_map<(d0, d1) -> (d0, d1)]}
37-
/// ```
38-
///
39-
/// After tiling the first op we get
40-
///
41-
/// ```
42-
/// %0:2 = scf.forall ... {
43-
/// %1:2 = linalg.generic {
44-
/// indexing_maps = [....,
45-
/// affine_map<(d0, d1, d2) -> (d0, d1),
46-
/// affine_map<(d0, d1, d2) -> (d0, d1)>]}
47-
/// }
48-
/// }
49-
/// %2 = linalg.generic ins(%0#0, %0#1) {
50-
/// indexing_maps = [affine_map<(d0, d1) -> (d0, d1),
51-
/// affine_map<(d0, d1) -> (d0, d1)]}
52-
/// ```
53-
///
54-
/// Due to a quirk of the fusion of consumers, fusing this consumer into the
55-
/// loop results in
56-
///
57-
/// ```
58-
/// %0:2 = scf.forall ... {
59-
/// %1:2 = linalg.generic {
60-
/// indexing_maps = [....,
61-
/// affine_map<(d0, d1, d2) -> (d0, d1),
62-
/// affine_map<(d0, d1, d2) -> (d0, d1)>]}
63-
/// %2 = tensor.extract_slice %0#1 [...]
64-
/// %3 = linalg.generic ins(%1#0, %2) {
65-
/// indexing_maps = [affine_map<(d0, d1) -> (d0, d1),
66-
/// affine_map<(d0, d1) -> (d0, d1)]}
67-
/// }
68-
/// }
69-
/// ```
70-
///
71-
/// This is an SSA violation because of `%0#1` being used in the loop. This
72-
/// needs to be fixed upstream, but for cases where
73-
/// 1. The root operation produces results using an identity indexing map (when
74-
/// ignoring the iteration space dimensions corresponding to the reduction
75-
/// loops)
76-
/// 2. For all consumers of the results of the root operation, access the data
77-
/// using identity indexing map then for each consumer fusion step it is valid
78-
/// to replace all uses of slices of the outer loop that occur within the loop
79-
/// with the correponding tiled result value.
80-
/// This is a workaround till upstream transformation can fix this issue. The
81-
/// following method is testing if such a case exists to implement the
82-
/// work-around.
83-
bool warForConsumerFusionSSAViolation(
84-
Operation *rootOp,
85-
const llvm::SmallDenseSet<Operation *> &tiledAndFusedOps);
86-
8727
/// Starting from `op` walk all operands backwards to find all
8828
/// potentially fusible operations, i.e. operations that implement
8929
/// the `TilingInterface`.
@@ -94,10 +34,12 @@ void collectTiledAndFusedOps(Operation *rootOp,
9434
// Returns a list of new `tensor.extract_slice` ops with new fusion
9535
// opportunities, as well as the new surrounding `scf.forall` (because consumer
9636
// fusion replaces the loop).
97-
FailureOr<std::queue<Operation *>>
98-
fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
99-
MutableArrayRef<LoopLikeOpInterface> loops,
100-
bool useWARForConsumerFusionSSAViolation);
37+
FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
38+
RewriterBase &rewriter, Operation *tiledOp,
39+
MutableArrayRef<LoopLikeOpInterface> loops,
40+
std::function<bool(Operation *)> filterFn = [](Operation *) {
41+
return true;
42+
});
10143

10244
} // namespace mlir::iree_compiler
10345

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,6 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
331331
mlir::DominanceInfo dominanceInfo(tilableOp);
332332
llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
333333
collectTiledAndFusedOps(tilableOp, tiledAndFusedOps);
334-
bool useWARForConsumerFusionSSAViolation =
335-
warForConsumerFusionSSAViolation(tilableOp, tiledAndFusedOps);
336334

337335
llvm::DenseSet<Operation *> yieldReplacementsFor;
338336
for (auto op : tiledAndFusedOps) {
@@ -413,13 +411,20 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
413411
return signalPassFailure();
414412
}
415413
for (auto [origValue, replacement] : tileAndFuseResult->replacements) {
416-
rewriter.replaceAllUsesWith(origValue, replacement);
414+
Value replacementCopy = replacement;
415+
rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) {
416+
Operation *user = use.getOwner();
417+
return !isa<tensor::DimOp>(user) &&
418+
dominanceInfo.dominates(replacementCopy, user);
419+
});
417420
}
418421
std::swap(tileAndFuseResult->loops, tilingLoops);
419422
Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
420423
FailureOr<std::queue<Operation *>> newFusionOpportunities =
421424
fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
422-
useWARForConsumerFusionSSAViolation);
425+
[&tiledAndFusedOps](Operation *op) {
426+
return tiledAndFusedOps.contains(op);
427+
});
423428
if (failed(newFusionOpportunities)) {
424429
rootTiledOp->emitOpError("failed to fuse consumers");
425430
return signalPassFailure();

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ func.func @pad_fusion(%0 : tensor<?x?xf32>, %1 : tensor<?x?xf32>, %2 : tensor<?x
826826

827827
// -----
828828

829-
// Test 1 of 2 that are testing a work-around for SSA violation issue with consumer fusion upstream.
829+
// Test 1 of 2 that are testing fusion while considering multiple slices.
830830

831831
func.func @horizontal_fusion_consumer_fusion1(%arg0 : tensor<2x4096x640xf16>,
832832
%arg1 : tensor<10x64x640xf16>, %arg2 : tensor<10x64x640xf16>, %arg3 : tensor<10x64x640xf16>)
@@ -893,7 +893,7 @@ func.func @horizontal_fusion_consumer_fusion1(%arg0 : tensor<2x4096x640xf16>,
893893

894894
// -----
895895

896-
// Test 2 of 2 that are testing a work-around for SSA violation issue with consumer fusion upstream.
896+
// Test 2 of 2 that are testing fusion while considering multiple slices.
897897

898898
func.func @horizontal_fusion_consumer_fusion2(%arg0 : tensor<2x4096x640xi8>,
899899
%arg1 : tensor<2x640x640xi8>, %arg2 : tensor<2x640x640xi8>) -> tensor<2x4096x640xf16> {
@@ -989,3 +989,87 @@ func.func @only_producer_fusion_multiple_result(%arg0: tensor<77x4096xf16>, %arg
989989
// CHECK: linalg.generic
990990
// CHECK: linalg.generic
991991
// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0
992+
993+
// -----
994+
995+
func.func @multi_slice_fusion_broadcast(%arg0: index, %arg1: tensor<3x?x32xi64>,
996+
%arg2: tensor<256x32xf32>, %arg3: tensor<32xf32>)
997+
-> (tensor<3x?x32x32xf32>, tensor<3x?x32x32xf32>) {
998+
%c32 = arith.constant 32 : index
999+
%c2_i64 = arith.constant 2 : i64
1000+
%cst = arith.constant 0.000000e+00 : f32
1001+
%cst_0 = arith.constant 3.200000e+01 : f32
1002+
%cst_1 = arith.constant 9.000000e+00 : f32
1003+
%0 = arith.divsi %arg0, %c32 : index
1004+
%1 = affine.apply affine_map<()[s0] -> (s0 floordiv 32)>()[%arg0]
1005+
%2 = tensor.empty(%1) : tensor<3x?x32x32xf32>
1006+
%3 = linalg.generic {
1007+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
1008+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
1009+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1010+
ins(%arg1 : tensor<3x?x32xi64>) outs(%2 : tensor<3x?x32x32xf32>) {
1011+
^bb0(%in: i64, %out: f32):
1012+
%8 = arith.index_cast %in : i64 to index
1013+
%9 = linalg.index 3 : index
1014+
%extracted = tensor.extract %arg2[%8, %9] : tensor<256x32xf32>
1015+
linalg.yield %extracted : f32
1016+
} -> tensor<3x?x32x32xf32>
1017+
%4 = tensor.empty(%0) : tensor<3x?x32xf32>
1018+
%5 = linalg.fill ins(%cst : f32)outs(%4 : tensor<3x?x32xf32>) -> tensor<3x?x32xf32>
1019+
%6 = linalg.generic {
1020+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
1021+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
1022+
iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
1023+
ins(%3 : tensor<3x?x32x32xf32>) outs(%5 : tensor<3x?x32xf32>)
1024+
attrs = {lowering_config = #iree_gpu.lowering_config<{reduction = [0, 0, 0, 4], thread = [1, 1, 1, 0], workgroup = [1, 1, 64, 0]}>} {
1025+
^bb0(%in: f32, %out: f32):
1026+
%8 = math.fpowi %in, %c2_i64 : f32, i64
1027+
%9 = arith.addf %8, %out : f32
1028+
linalg.yield %9 : f32
1029+
} -> tensor<3x?x32xf32>
1030+
%7 = linalg.generic {
1031+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,
1032+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
1033+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
1034+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
1035+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1036+
ins(%arg3, %3, %6 : tensor<32xf32>, tensor<3x?x32x32xf32>, tensor<3x?x32xf32>)
1037+
outs(%2 : tensor<3x?x32x32xf32>) {
1038+
^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
1039+
%8 = arith.divf %in_3, %cst_0 : f32
1040+
%9 = arith.addf %8, %cst_1 : f32
1041+
%10 = math.rsqrt %9 : f32
1042+
%11 = arith.mulf %in_2, %10 : f32
1043+
%12 = arith.mulf %in, %11 : f32
1044+
linalg.yield %12 : f32
1045+
} -> tensor<3x?x32x32xf32>
1046+
return %3, %7 : tensor<3x?x32x32xf32>, tensor<3x?x32x32xf32>
1047+
}
1048+
// CHECK-LABEL: func @multi_slice_fusion_broadcast
1049+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<3x?x32xi64>
1050+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: tensor<32xf32>
1051+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1052+
// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
1053+
// CHECK: %[[EMPTY:.+]] = tensor.empty
1054+
// CHECK: %[[RESULT:.+]]:2 = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]])
1055+
// CHECK-SAME: shared_outs(%[[INIT0:[a-zA-Z0-9]+]] = %[[EMPTY]], %[[INIT1:[a-zA-Z0-9]+]] = %[[EMPTY]])
1056+
// CHECK-DAG: %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV0]], %[[IV1]], 0, 0] [1, 1, 32, 32]
1057+
// CHECK-DAG: %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[IV0]], %[[IV1]], 0] [1, 1, 32]
1058+
// CHECK: %[[GENERIC0:.+]] = linalg.generic
1059+
// CHECK-SAME: ins(%[[ARG1_SLICE]] :
1060+
// CHECK-SAME: outs(%[[INIT0_SLICE]] :
1061+
// CHECK: %[[CAST0:.+]] = tensor.cast %[[GENERIC0]]
1062+
// CHECK: %[[EMPTYTILE:.+]] = tensor.empty() : tensor<1x1x32xf32>
1063+
// CHECK: %[[FILL:.+]] = linalg.fill
1064+
// CHECK-SAME: outs(%[[EMPTYTILE]] :
1065+
// CHECK: %[[GENERIC1:.+]] = linalg.generic
1066+
// CHECK-SAME: ins(%[[GENERIC0]] :
1067+
// CHECK-SAME: outs(%[[FILL]] :
1068+
// CHECK: %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV0]], %[[IV1]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1]
1069+
// CHECK: %[[GENERIC2:.+]] = linalg.generic
1070+
// CHECK-SAME: ins(%[[ARG3]], %[[GENERIC0]], %[[GENERIC1]] :
1071+
// CHECK-SAME: outs(%[[INIT1_SLICE]] :
1072+
// CHECK: %[[CAST1:.+]] = tensor.cast %[[GENERIC2]]
1073+
// CHECK-DAG: tensor.parallel_insert_slice %[[CAST0]] into %[[INIT0]][%[[IV0]], %[[IV1]], %[[C0]], 0] [1, 1, %[[C32]], 32]
1074+
// CHECK-DAG: tensor.parallel_insert_slice %[[CAST1]] into %[[INIT1]][%[[IV0]], %[[IV1]], %[[C0]], 0] [1, 1, %[[C32]], 32]
1075+
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1

0 commit comments

Comments
 (0)