
Commit 469fa4e

pravg-amdhhkit authored and committed
Revert "[codegen] more consumer fusion (iree-org#21521)" (iree-org#21819)
This reverts commit 4d91ffb. The reverted commit causes a compilation failure for the llama 405B fp4 model. Ticket tracking the issue: iree-org#21814.

Signed-off-by: Praveen G <[email protected]>
Signed-off-by: Ivan Ho <[email protected]>
1 parent d337ad2 commit 469fa4e

File tree: 4 files changed, +24 additions, -197 deletions


compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 10 additions & 62 deletions
@@ -9,7 +9,6 @@
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
@@ -88,51 +87,16 @@ void collectTiledAndFusedOps(Operation *rootOp,
   }
 }
 
-namespace {
-// Entry for the pseudo-priority queue of consumer fusion candidates. Contains
-// the consumer (fusableUser) that can be fused and the set of slice operations
-// in the loop to fuse into that feed the consumer.
-struct ConsumerFusionQueueEntry {
-  ConsumerFusionQueueEntry(SmallVector<Operation *> &&slices,
-                           Operation *fusableUser)
-      : slices(std::move(slices)), fusableUser(fusableUser) {}
-
-  SmallVector<Operation *> slices;
-  Operation *fusableUser;
-};
-} // namespace
-
 FailureOr<std::queue<Operation *>>
-fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
+fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
                         MutableArrayRef<LoopLikeOpInterface> loops,
                         std::function<bool(Operation *)> filterFn) {
   // Collect the candidate slices which can be potential consumers that can be
-  // fused. Keep them in a vector reverse-sorted by dominance: the candidate
-  // dominating others comes last (so it can be cheaply popped from the vector).
-  // The most-dominating candidate is to be fused first since not fusing it may
-  // prevent dominated candidates to be fused:
-  //
-  //   A
-  //   |
-  //   B
-  //  / |
-  //  | D
-  //  | /
-  //   C
-  //
-  // here, B must be fused before both C and D, and D must be fused before C.
-  // Candidates are kept in a vector rather than a priority queue since we may
-  // update them as fusion happens, in particular, more slices may need to be
-  // handled. For example, fusing B with A will create a slice of B that will
-  // need to be handled correctly.
-  SmallVector<ConsumerFusionQueueEntry> candidates;
+  // fused.
+  std::queue<SmallVector<Operation *>> candidates;
   llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
   auto addCandidateSlices = [&candidates, &allCandidates,
                              &filterFn](Operation *fusedOp) {
-    // Dominance info recreated since op creation/movement in the fusion logic
-    // invalidates it anyway.
-    DominanceInfo dominanceInfo;
-
     for (auto *userOp : fusedOp->getResults().getUsers()) {
       auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
       if (!sliceOp || allCandidates.contains(sliceOp)) {
@@ -171,38 +135,22 @@ fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
         allCandidates.insert_range(slices);
       }
       if (!fusedSlices.empty()) {
-        ConsumerFusionQueueEntry entry(std::move(fusedSlices), fusableUser);
-
-        // Comparator that puts the dominating user last.
-        auto comp = [&](const ConsumerFusionQueueEntry &lhs,
-                        const ConsumerFusionQueueEntry &rhs) {
-          return dominanceInfo.properlyDominates(rhs.fusableUser,
-                                                 lhs.fusableUser);
-        };
-
-        // If the fusable user is already a candidate, update it with the new
-        // list of slices to handle. Otherwise, insert it into the right
-        // position based on dominance.
-        auto *it = llvm::lower_bound(candidates, entry, comp);
-        if (it != candidates.end() && it->fusableUser == fusableUser)
-          *it = std::move(entry);
-        else
-          candidates.insert(it, std::move(entry));
+        candidates.emplace(std::move(fusedSlices));
       }
     }
   };
 
-  // Add slices from all tiled ops, not only the "main" one.
-  for (Operation *tiledOp : tiledOps)
-    addCandidateSlices(tiledOp);
+  addCandidateSlices(tiledOp);
 
   std::queue<Operation *> newFusionOpportunities;
   while (!candidates.empty()) {
-    // Get the next candidate.
-    ConsumerFusionQueueEntry entry = candidates.pop_back_val();
+    // Traverse the slices in BFS fashion.
+    SmallVector<Operation *> candidateSlices = candidates.front();
+    candidates.pop();
 
     FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
-        mlir::scf::tileAndFuseConsumerOfSlices(rewriter, entry.slices, loops);
+        mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
+                                               loops);
     if (failed(fusedResult)) {
       return failure();
    }
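The core of this revert, visible in the hunks above, is the consumer-fusion worklist: the reverted change kept candidates in a vector sorted so that the most-dominating consumer is popped first, while the restored code goes back to a plain FIFO std::queue processed in discovery order. Below is a minimal standalone sketch, not IREE code: it approximates dominance with a single integer program order and uses made-up candidate names, purely to illustrate the ordering difference.

// Standalone illustration only (assumed names, no MLIR dependencies):
// "dominance" is modeled by an integer where a smaller value dominates.
#include <algorithm>
#include <cstdio>
#include <queue>
#include <string>
#include <vector>

struct Candidate {
  std::string name;
  int order; // stand-in for dominance: smaller order dominates larger order
};

int main() {
  // Candidates discovered in arbitrary order; B dominates D, D dominates C.
  const std::vector<Candidate> discovered = {{"C", 3}, {"B", 1}, {"D", 2}};

  // Restored behavior: FIFO queue, fused in discovery order (C, B, D).
  std::queue<Candidate> fifo;
  for (const Candidate &c : discovered)
    fifo.push(c);
  std::printf("FIFO order:      ");
  while (!fifo.empty()) {
    std::printf("%s ", fifo.front().name.c_str());
    fifo.pop();
  }
  std::printf("\n");

  // Behavior removed by the revert: keep a vector sorted so the most
  // dominating candidate sits at the back and is popped first (B, D, C).
  auto comp = [](const Candidate &lhs, const Candidate &rhs) {
    // The dominating candidate compares "greater" so it lands at the back.
    return rhs.order < lhs.order;
  };
  std::vector<Candidate> sorted;
  for (const Candidate &c : discovered)
    sorted.insert(std::lower_bound(sorted.begin(), sorted.end(), c, comp), c);
  std::printf("Dominance order: ");
  while (!sorted.empty()) {
    std::printf("%s ", sorted.back().name.c_str());
    sorted.pop_back();
  }
  std::printf("\n");
  return 0;
}

In the removed comment's example graph, B must be fused before both C and D, which is why the reverted code processed the dominating candidate first.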

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h

Lines changed: 5 additions & 4 deletions
@@ -31,11 +31,12 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
 void collectTiledAndFusedOps(Operation *rootOp,
                              llvm::SmallDenseSet<Operation *> &result);
 
-/// Fuse all consumers of the given `tiledOps` into the surrounding `scf.forall`
-/// unless specified otherwise by `filterFn`. Returns a list of new
-/// `tensor.extract_slice` ops with new fusion opportunities.
+/// Fuse all consumers of the given `tiledOp` into the surrounding `scf.forall`.
+/// Returns a list of new `tensor.extract_slice` ops with new fusion
+/// opportunities, as well as the new surrounding `scf.forall` (because consumer
+/// fusion replaces the loop).
 FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
-    RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
+    RewriterBase &rewriter, Operation *tiledOp,
     MutableArrayRef<LoopLikeOpInterface> loops,
     std::function<bool(Operation *)> filterFn = [](Operation *) {
       return true;

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 9 additions & 27 deletions
@@ -210,18 +210,6 @@ static bool verifyComputeOpsAfterDistribution(FunctionOpInterface funcOp) {
 // Pass implementation.
 //===---------------------------------------------------------------------===//
 
-/// Returns true if any value produced by `producer` is used as an init value
-/// for the DPS `user`. Returns false if the user is not in DPS.
-static bool isUsedAsInit(Operation *producer, Operation *user) {
-  auto dpsIface = dyn_cast<DestinationStyleOpInterface>(user);
-  if (!dpsIface)
-    return false;
-  ValueRange results = producer->getResults();
-  return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) {
-    return llvm::is_contained(results, operand);
-  });
-}
-
 void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
   auto funcOp = getOperation();
   auto *context = &getContext();
@@ -244,14 +232,10 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
 
   llvm::DenseSet<Operation *> yieldReplacementsFor;
   for (auto op : tiledAndFusedOps) {
-    // Require replacement for values that are used after the main tilable op or
-    // by ops that will definitely not be fused. Note that if a value is used as
-    // an init of a DPS op, the user currently cannot be fused. Having a
-    // replacement for it would attempt fusion and fail, so avoid such cases.
+    // If tiledAndFused ops doesn't contain the user; add an replacement
+    // for that.
     if (llvm::any_of(op->getUsers(), [&](Operation *user) {
-          if (isUsedAsInit(op, user))
-            return false;
-          return dominanceInfo.properlyDominates(tilableOp, user) ||
+          return dominanceInfo.properlyDominates(tilableOp, user) &&
                  !tiledAndFusedOps.contains(user);
         })) {
       yieldReplacementsFor.insert(op);
@@ -333,18 +317,16 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
     });
   }
   std::swap(tileAndFuseResult->loops, tilingLoops);
-
+  Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
   FailureOr<std::queue<Operation *>> newFusionOpportunities =
-      fuseConsumersIntoForall(
-          rewriter, tileAndFuseResult->tiledAndFusedOps.getArrayRef(),
-          tilingLoops, [&tiledAndFusedOps](Operation *op) {
-            return tiledAndFusedOps.contains(op);
-          });
+      fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
+                              [&tiledAndFusedOps](Operation *op) {
+                                return tiledAndFusedOps.contains(op);
+                              });
   if (failed(newFusionOpportunities)) {
     // Continue the work if the failure is allowed.
     if (!verifyComputeOpsAfterDistribution(funcOp)) {
-      tileAndFuseResult->tiledAndFusedOps.front()->emitOpError(
-          "failed to fuse consumers");
+      rootTiledOp->emitOpError("failed to fuse consumers");
       return signalPassFailure();
     }
  } else {
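The behavioral change in this file is the predicate deciding which tiled-and-fused values get a yield replacement: the reverted change used `||` plus a DPS-init guard (`isUsedAsInit`), while the restored code uses `&&` with no guard. The sketch below is a standalone illustration, not IREE code; the boolean parameters are hypothetical stand-ins for the dominance, fusion-set containment, and DPS-init checks, and it simply prints the truth table of the two predicates.

// Standalone illustration only (assumed parameter names, no MLIR types).
#include <cstdio>

// Predicate removed by this revert: replace if the user comes after the
// tilable op OR will not be fused, skipping users that consume the value as a
// DPS init.
static bool needsReplacementBeforeRevert(bool usedAsInit, bool dominated,
                                         bool userWillBeFused) {
  if (usedAsInit)
    return false;
  return dominated || !userWillBeFused;
}

// Restored predicate: replace only if the user comes after the tilable op
// AND will not be fused.
static bool needsReplacementAfterRevert(bool dominated, bool userWillBeFused) {
  return dominated && !userWillBeFused;
}

int main() {
  std::printf("init dominated fused | before after\n");
  for (int init = 0; init <= 1; ++init)
    for (int dom = 0; dom <= 1; ++dom)
      for (int fused = 0; fused <= 1; ++fused)
        std::printf("   %d         %d     %d |    %d     %d\n", init, dom,
                    fused, needsReplacementBeforeRevert(init, dom, fused),
                    needsReplacementAfterRevert(dom, fused));
  return 0;
}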

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_workgroups_using_forall.mlir

Lines changed: 0 additions & 104 deletions
@@ -1083,107 +1083,3 @@ func.func @infusible_pack(%arg0 : tensor<30xf32>) -> tensor<5x6xf32> {
 // CHECK: linalg.generic
 // CHECK: scf.forall.in_parallel {
 // CHECK: linalg.pack
-
-// -----
-
-// Adapted from layer normalization. The graph structure is as follows
-//
-//       %14
-//      / | \
-//     /  %15  %17
-//     |   |  /  |
-//     |  [%19]  |
-//   %21    |   %22
-//     |    |    |
-//     v    v    v
-//
-// In particular, %21 and %22 are not users of the "main" tilable
-// operation but we still want them to be fused. %19, %21 and %22
-// all produce results returned from the function.
-//
-// Check that everything is fused and that there are three results
-// from the loop being produced and returned.
-//
-// CHECK-LABEL: @multi_result_consumer_fusion
-// CHECK-NOT: linalg.generic
-// CHECK: %[[LOOP:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (16, 256) shared_outs(%[[OUT0:.+]] = %{{.+}}, %[[OUT1:.+]] = %{{.+}}, %[[OUT2:.+]] = %{{.+}})
-// CHECK: %[[v14:.+]] = linalg.generic
-// CHECK: arith.divf
-// CHECK: %[[v15:.+]] = linalg.generic
-// CHECK: arith.subf
-// CHECK: %[[v17:.+]] = linalg.generic
-// CHECK: arith.divf
-// CHECK: math.rsqrt
-// CHECK: %[[RES0:.+]] = linalg.generic
-// CHECK: arith.mulf
-// CHECK: arith.extf
-// CHECK: arith.mulf
-// CHECK: arith.extf
-// CHECK: arith.addf
-// CHECK: arith.truncf
-// CHECK: %[[RES1:.+]] = linalg.generic {{.*}} ins(%[[v14]] :
-// CHECK: arith.truncf
-// CHECK: %[[RES2:.+]] = linalg.generic {{.*}} ins(%[[v17]] :
-// CHECK: arith.truncf
-// CHECK: scf.forall.in_parallel
-// CHECK: tensor.parallel_insert_slice %[[RES0]] into %[[OUT0]]
-// CHECK: tensor.parallel_insert_slice %[[RES1]] into %[[OUT1]]
-// CHECK: tensor.parallel_insert_slice %[[RES2]] into %[[OUT2]]
-// CHECK-NOT: linalg.generic
-// CHECK: return %[[LOOP]]#0, %[[LOOP]]#1, %[[LOOP]]#2
-func.func @multi_result_consumer_fusion(
-    %6: tensor<16x256x2048xbf16>,
-    %7: tensor<2048xbf16>,
-    %8: tensor<2048xbf16>,
-    %10: tensor<16x256x2048xf32>,
-    %13: tensor<16x256xf32>
-) -> (
-    tensor<16x256x2048xbf16>,
-    tensor<16x256xbf16>,
-    tensor<16x256xbf16>
-) {
-  %cst = arith.constant 0.000000e+00 : f32
-  %cst_0 = arith.constant 2.048000e+03 : f32
-  %c0 = arith.constant 0 : index
-  %9 = tensor.empty() : tensor<16x256x2048xf32>
-  %11 = tensor.empty() : tensor<16x256xf32>
-  %14 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%13 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %23 = arith.divf %in, %cst_0 : f32
-    linalg.yield %23 : f32
-  } -> tensor<16x256xf32>
-  %15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%10, %14 : tensor<16x256x2048xf32>, tensor<16x256xf32>) outs(%9 : tensor<16x256x2048xf32>) {
-  ^bb0(%in: f32, %in_1: f32, %out: f32):
-    %23 = arith.subf %in, %in_1 : f32
-    linalg.yield %23 : f32
-  } -> tensor<16x256x2048xf32>
-  %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%11 : tensor<16x256xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    %23 = arith.divf %in, %cst_0 : f32
-    %24 = math.rsqrt %23 : f32
-    linalg.yield %24 : f32
-  } -> tensor<16x256xf32>
-  %18 = tensor.empty() : tensor<16x256x2048xbf16>
-  %19 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%15, %17, %7, %8 : tensor<16x256x2048xf32>, tensor<16x256xf32>, tensor<2048xbf16>, tensor<2048xbf16>) outs(%18 : tensor<16x256x2048xbf16>) attrs = {lowering_config = #iree_gpu.lowering_config<{lane_basis = [[1, 1, 64], [0, 1, 2]], reduction = [0, 0, 256], subgroup_basis = [[1, 1, 1], [0, 1, 2]], thread = [0, 0, 4], workgroup = [1, 1, 0]}>} {
-  ^bb0(%in: f32, %in_1: f32, %in_2: bf16, %in_3: bf16, %out: bf16):
-    %23 = arith.mulf %in, %in_1 : f32
-    %24 = arith.extf %in_2 : bf16 to f32
-    %25 = arith.mulf %23, %24 : f32
-    %26 = arith.extf %in_3 : bf16 to f32
-    %27 = arith.addf %25, %26 : f32
-    %28 = arith.truncf %27 : f32 to bf16
-    linalg.yield %28 : bf16
-  } -> tensor<16x256x2048xbf16>
-  %20 = tensor.empty() : tensor<16x256xbf16>
-  %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%14 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
-  ^bb0(%in: f32, %out: bf16):
-    %23 = arith.truncf %in : f32 to bf16
-    linalg.yield %23 : bf16
-  } -> tensor<16x256xbf16>
-  %22 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17 : tensor<16x256xf32>) outs(%20 : tensor<16x256xbf16>) {
-  ^bb0(%in: f32, %out: bf16):
-    %23 = arith.truncf %in : f32 to bf16
-    linalg.yield %23 : bf16
-  } -> tensor<16x256xbf16>
-  return %19, %21, %22 : tensor<16x256x2048xbf16>, tensor<16x256xbf16>, tensor<16x256xbf16>
-}
