Skip to content

Commit c0c9b69

Browse files
authored
[Dispatch Creation] Don't add unfusable consumers to fusion group (#22461)
There are cases where two ops can't be fused because doing so would result in a use-def violation (see the diagram below). Currently, they are still placed in the same fusion group. This means that the consumer gets marked by the analysis as being in a fusion group, causing it to potentially miss out on fusion opportunities. This change moves this check so that it occurs during the analysis and modifies the old check to error out.

```
A (in fusion group)
| \
|  \
|   v
|   B (unfusable consumer of A)
|  /
| /
v v
C (trying to fuse with A)
```

---------

Signed-off-by: Ian Wood <[email protected]>
1 parent dca3747 commit c0c9b69

File tree

2 files changed

+109
-6
lines changed

2 files changed

+109
-6
lines changed

compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/ADT/TypeSwitch.h"
2121
#include "llvm/Support/Casting.h"
2222
#include "llvm/Support/Debug.h"
23+
#include "mlir/Analysis/SliceAnalysis.h"
2324
#include "mlir/Analysis/TopologicalSortUtils.h"
2425
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2526
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -43,6 +44,7 @@
4344
#include "mlir/Pass/Pass.h"
4445
#include "mlir/Support/LLVM.h"
4546
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
47+
#include "mlir/Transforms/RegionUtils.h"
4648

4749
#define DEBUG_TYPE "iree-dispatch-creation-form-dispatch-regions"
4850

@@ -145,6 +147,50 @@ class FusionGroup {
145147
// Insert `op` into the fusion group.
146148
void insert(Operation *op);
147149

150+
/// Returns true if `consumerOp` has a transitive dependency on the fusion
151+
/// group. This means that some transitive dependency of `consumerOp` (not in
152+
/// the fusion group) itself uses an operation in the fusion group. This is
153+
/// required for fusion because it must be legal to take a program slice that
154+
/// contains only the ops in the fusion group.
155+
bool
156+
hasTransitiveDependencyOnFusionGroup(Operation *consumerOp,
157+
DominanceInfo const &dominance) const {
158+
BackwardSliceOptions options;
159+
options.inclusive = true;
160+
options.omitUsesFromAbove = false;
161+
options.omitBlockArguments = true;
162+
options.filter = [&](Operation *sliceBoundaryOp) {
163+
return !llvm::all_of(
164+
loopMaps.getArrayRef(), [&](std::pair<Operation *, AffineMap> pair) {
165+
return dominance.properlyDominates(sliceBoundaryOp, pair.first);
166+
});
167+
};
168+
169+
llvm::SetVector<Operation *> slice;
170+
auto populateSlice = [&](OpOperand *operand) {
171+
// It's okay if the consumer directly uses an operation in the fusion
172+
// group.
173+
if (loopMaps.contains(operand->get().getDefiningOp())) {
174+
return;
175+
}
176+
LogicalResult result = getBackwardSlice(operand->get(), &slice, options);
177+
assert(result.succeeded() && "expected a backward slice");
178+
(void)result;
179+
};
180+
181+
// Search all of the operands of `consumerOp` as well as all the values used
182+
// in its regions.
183+
mlir::visitUsedValuesDefinedAbove(consumerOp->getRegions(), populateSlice);
184+
for (OpOperand &operand : consumerOp->getOpOperands()) {
185+
populateSlice(&operand);
186+
}
187+
188+
return llvm::any_of(loopMaps.getArrayRef(),
189+
[&](std::pair<Operation *, AffineMap> pair) {
190+
return slice.contains(pair.first);
191+
});
192+
}
193+
148194
private:
149195
Operation *rootOp;
150196
// All operations to be fused with the root op. This does not include
@@ -435,6 +481,9 @@ getFusableUses(MLIRContext *context, Operation *op,
435481
if (isa<tensor::DimOp>(user)) {
436482
continue;
437483
}
484+
if (op->getBlock() != user->getBlock()) {
485+
continue;
486+
}
438487
fusableUses.insert(&use);
439488
}
440489

@@ -667,6 +716,13 @@ fuseRootsWithConsumers(MLIRContext *context, ArrayRef<Operation *> roots,
667716
continue;
668717
}
669718

719+
// Ensure that fusing the consumer would not cause use-def violations.
720+
if (tracker.getFusionGroup(currRoot)
721+
.hasTransitiveDependencyOnFusionGroup(fusableUse->getOwner(),
722+
dominanceInfo)) {
723+
continue;
724+
}
725+
670726
if (isFusableWithConsumer(*fusableUse, tracker, options)) {
671727
tracker.appendToFusionGroup(consumerOp, fusionGroup);
672728
workList.push_back(consumerOp);
@@ -957,7 +1013,8 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
9571013
auto newRegionOp =
9581014
movePrecedingOpsIntoDispatchRegion(rewriter, producer, regionOp);
9591015
if (failed(newRegionOp)) {
960-
return producer->emitOpError("failed to move producer into region");
1016+
producer->emitWarning("failed to move producer into region");
1017+
continue;
9611018
}
9621019
regionOp = *newRegionOp;
9631020
}
@@ -974,7 +1031,7 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
9741031
auto newRegionOp = IREE::Flow::moveFollowingOpIntoDispatchRegion(
9751032
rewriter, consumer, regionOp);
9761033
if (failed(newRegionOp)) {
977-
continue;
1034+
return consumer->emitOpError("failed to move consumer into region");
9781035
}
9791036
regionOp = *newRegionOp;
9801037
}

compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,15 +1391,16 @@ util.func public @avoid_illegal_consumer_fusion(%arg0: tensor<75600x5120xbf16>)
13911391
util.return %6 : tensor<75600x1x5120xbf16>
13921392
}
13931393
// CHECK-LABEL: @avoid_illegal_consumer_fusion(
1394-
// CHECK: %[[DISPATCH:.+]]:2 = flow.dispatch.region
1394+
// CHECK: %[[DISPATCH0:.+]]:2 = flow.dispatch.region
13951395
// CHECK: %[[GENERIC0:.+]] = linalg.generic
13961396
// CHECK: %[[GENERIC1:.+]] = linalg.generic
13971397
// CHECK-SAME: ins(%[[GENERIC0]] :
13981398
// CHECK: flow.return %[[GENERIC1]], %[[GENERIC0]]
1399-
// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[DISPATCH]]#1
1399+
// CHECK: %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[DISPATCH0]]#1
1400+
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
14001401
// CHECK: %[[GENERIC2:.+]] = linalg.generic
1401-
// CHECK-SAME: ins(%[[EXPAND_SHAPE]], %[[DISPATCH]]#0 :
1402-
// CHECK: util.return %[[GENERIC2]]
1402+
// CHECK-SAME: ins(%[[EXPAND_SHAPE]], %[[DISPATCH0]]#0 :
1403+
// CHECK: util.return %[[DISPATCH1]]
14031404

14041405
// -----
14051406

@@ -1791,3 +1792,48 @@ util.func public @dont_fuse_producer_matmuls(%arg0 : tensor<4x4x7xf32>, %arg1 :
17911792
// CHECK: %[[OP:.+]] = linalg.generic
17921793
// CHECK-SAME: ins(%[[DISPATCH0]]
17931794
// CHECK: flow.return %[[OP]]
1795+
1796+
// -----
1797+
1798+
util.func public @no_fusion_across_blocks(%arg0: tensor<3x2xf32>) -> tensor<f32> {
1799+
%cst = arith.constant 0.000000e+00 : f32
1800+
%0 = tensor.empty() : tensor<f32>
1801+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<f32>) -> tensor<f32>
1802+
%2 = tensor.empty() : tensor<3x2xf32>
1803+
%4 = linalg.generic {
1804+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
1805+
affine_map<(d0, d1) -> ()>],
1806+
iterator_types = ["reduction", "reduction"]}
1807+
ins(%arg0 : tensor<3x2xf32>) outs(%1 : tensor<f32>) {
1808+
^bb0(%in: f32, %out: f32):
1809+
%9 = arith.addf %in, %out : f32
1810+
linalg.yield %9 : f32
1811+
} -> tensor<f32>
1812+
// Dispatch region uses the reduction result
1813+
%5 = flow.dispatch.region -> (tensor<f32>) {
1814+
%9 = linalg.generic {
1815+
indexing_maps = [affine_map<() -> ()>,
1816+
affine_map<() -> ()>,
1817+
affine_map<() -> ()>],
1818+
iterator_types = []}
1819+
ins(%4, %1 : tensor<f32>, tensor<f32>) outs(%0 : tensor<f32>) {
1820+
^bb0(%in: f32, %in_0: f32, %out: f32):
1821+
%10 = arith.divf %in, %in_0 : f32
1822+
linalg.yield %10 : f32
1823+
} -> tensor<f32>
1824+
flow.return %9 : tensor<f32>
1825+
}
1826+
util.return %5 : tensor<f32>
1827+
}
1828+
// CHECK-LABEL: util.func public @no_fusion_across_blocks
1829+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<3x2xf32>
1830+
// CHECK: %[[FILL:.+]] = linalg.fill
1831+
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
1832+
// CHECK: %[[REDUCTION:.+]] = linalg.generic
1833+
// CHECK-SAME: ins(%[[ARG0]]
1834+
// CHECK: flow.return %[[REDUCTION]]
1835+
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region
1836+
// CHECK: %[[DIV:.+]] = linalg.generic
1837+
// CHECK-SAME: ins(%[[DISPATCH0]], %[[FILL]]
1838+
// CHECK: flow.return %[[DIV]]
1839+
// CHECK: util.return %[[DISPATCH1]]

0 commit comments

Comments
 (0)