Skip to content

Commit 8cbca59

Browse files
IanWood1MaheshRavishankar
authored and committed
[Dispatch Creation] Fuse bit-truncate ops with producers (iree-org#21346)
Ensures bit-truncate ops get fused with their producers during dispatch creation by preventing elementwise fusion with consumers that would cause them to no longer fuse with their producer. --------- Signed-off-by: MaheshRavishankar <[email protected]> Signed-off-by: Ian Wood <[email protected]> Co-authored-by: MaheshRavishankar <[email protected]> Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 9f80663 commit 8cbca59

File tree

13 files changed

+211
-21
lines changed

13 files changed

+211
-21
lines changed

compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,14 @@ void ElementwiseOpFusionPass::runOnOperation() {
138138
[&](OpOperand *fusedOperand) {
139139
Operation *producer = fusedOperand->get().getDefiningOp();
140140
Operation *consumer = fusedOperand->getOwner();
141+
if (!producer || !consumer) {
142+
return false;
143+
}
141144

142-
if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
145+
// If `intraDispatch` is false, make sure that the producer and consumer
146+
// are outside dispatch.
147+
if (!intraDispatch &&
148+
!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
143149
return false;
144150
}
145151

@@ -158,11 +164,14 @@ void ElementwiseOpFusionPass::runOnOperation() {
158164
if (operands.size() >= kIreeMaxOperandCount)
159165
return false;
160166

161-
return areFusableAsElementwiseOps(context, fusedOperand,
162-
fuseMultiReduction);
167+
ElementwiseOpsFusabilityOptions options;
168+
options.fuseMultiReduction = fuseMultiReduction;
169+
options.fuseTruncateOps = fuseTruncateOps;
170+
return areFusableAsElementwiseOps(context, fusedOperand, options);
163171
};
164172

165173
RewritePatternSet linalgFusionPatterns(context);
174+
linalgFusionPatterns.insert<GatherFusionPattern>(context);
166175
linalg::populateElementwiseOpsFusionPatterns(linalgFusionPatterns,
167176
fuseElementwiseOpsControlFn);
168177

@@ -185,7 +194,6 @@ void ElementwiseOpFusionPass::runOnOperation() {
185194
RewritePatternSet linalgExtFusionPatterns(context);
186195
IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(
187196
linalgExtFusionPatterns, foldTransposeControlFn);
188-
linalgExtFusionPatterns.insert<GatherFusionPattern>(context);
189197
if (failed(applyPatternsGreedily(
190198
getOperation(), std::move(linalgExtFusionPatterns), rewriteConfig))) {
191199
getOperation()->emitOpError(

compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,16 @@ static bool checkContractionOpEquivalence(MLIRContext *context, Operation *aOp,
191191
}
192192
}
193193

194+
// TODO(#20116): hack to prevent codegen failure for small horizontally fused
195+
// matmuls that go down LLVMGPUDistribute.
196+
unsigned mDimsSize = 1;
197+
for (unsigned dim : aContractionDims.value().m) {
198+
mDimsSize *= aStaticDims[dim];
199+
}
200+
if (mDimsSize < 16) {
201+
return false;
202+
}
203+
194204
auto checkSameRankAndElementType = [](Value aVal, Value bVal) {
195205
auto aType = dyn_cast<ShapedType>(aVal.getType());
196206
auto bType = dyn_cast<ShapedType>(bVal.getType());

compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,13 +197,19 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
197197
continue;
198198
}
199199

200-
// 7. Skip dequantization-like `producer` ops as we would rather fuse
200+
// 7. Skip bit-extend-like `producer` ops as we would rather fuse
201201
// by cloning the producer instead of multi-use fusion.
202202
if (IREE::LinalgExt::isBitExtendOp(producer)) {
203203
return;
204204
}
205205

206-
// 8. All uses from `producer` -> `consumer` need to be fusable.
206+
// 8. Skip bit-truncate-like `producer` ops as we would rather fuse
207+
// these operations with their producers.
208+
if (IREE::LinalgExt::isBitTruncateOp(producer)) {
209+
return;
210+
}
211+
212+
// 9. All uses from `producer` -> `consumer` need to be fusable.
207213
// Without this the `producer` is still live, and there is no
208214
// advantage to do the fusion.
209215
if (llvm::any_of(getAllUsesInConsumer(producer, genericOp),

compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
namespace mlir::iree_compiler::DispatchCreation {
1919

2020
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
21-
bool fuseMultiReduction) {
21+
ElementwiseOpsFusabilityOptions options) {
2222
Operation *producerOp = fusedOperand->get().getDefiningOp();
2323
Operation *consumerOp = fusedOperand->getOwner();
2424
if (!producerOp)
@@ -76,6 +76,26 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
7676
if (!linalgConsumerOp) {
7777
return false;
7878
}
79+
80+
if (!options.fuseTruncateOps &&
81+
IREE::LinalgExt::isBitTruncateOp(producerOp)) {
82+
// Do not fuse bit-truncate-like operations with their consumers
83+
// unless:
84+
//
85+
// 1. The consumer has only one ins operand and is an elementwise
86+
// operation. The elementwise operation implies that the `outs` operand is
87+
// not real usage (and is typically a `tensor.empty`), so the core condition
88+
// is that there is only one "real" operand of the consumer.
89+
//
90+
// 2. The consumer is also a truncate (e.g. trunc from f32 to f16 to f8).
91+
bool isUnaryElementwise = linalgConsumerOp.getNumLoops() ==
92+
linalgConsumerOp.getNumParallelLoops() &&
93+
linalgConsumerOp.getNumDpsInputs() == 1;
94+
if (!IREE::LinalgExt::isBitTruncateOp(consumerOp) && !isUnaryElementwise) {
95+
return false;
96+
}
97+
}
98+
7999
// If the producer has a single use (this op), only fuse if
80100
// - 1) The consumer op is all parallel loops. The parallelism of the consumer
81101
// can be used as a way to amortize cost of redundant computation
@@ -89,7 +109,8 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
89109
.isPermutation()) {
90110
return false;
91111
}
92-
if (!fuseMultiReduction && linalgConsumerOp.getNumReductionLoops() != 1) {
112+
if (!options.fuseMultiReduction &&
113+
linalgConsumerOp.getNumReductionLoops() != 1) {
93114
return false;
94115
}
95116
if (linalg::isaContractionOpInterface(linalgConsumerOp) ||

compiler/src/iree/compiler/DispatchCreation/FusionUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@ namespace mlir::iree_compiler::DispatchCreation {
1818

1919
/// Return true if the producer and consumer of `operand` are fusable
2020
/// using elementwise op fusion transformation.
21+
struct ElementwiseOpsFusabilityOptions {
22+
// Control fusion with consumer that has multiple reduction dimensions.
23+
bool fuseMultiReduction = false;
24+
// Control fusion with producer that is a truncate-like operation.
25+
bool fuseTruncateOps = false;
26+
};
2127
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
22-
bool fuseMultiReduction);
28+
ElementwiseOpsFusabilityOptions options);
2329

2430
/// Move the definition of operands of `operations` before `insertionPoint`.
2531
LogicalResult moveOperandDefs(RewriterBase &rewriter,

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ static llvm::cl::opt<bool> clEnableElementWiseFuseMultiReduction(
3333
llvm::cl::desc("Enable element-wise fusion of multi-reduction loop ops."),
3434
llvm::cl::init(true));
3535

36+
static llvm::cl::opt<bool> clEnableEarlyTruncFusion(
37+
"iree-dispatch-creation-enable-early-trunc-fusion",
38+
llvm::cl::desc(
39+
"Enable element-wise fusion of bit-truncate operation with their "
40+
"consumers before forming dispatch regions"),
41+
llvm::cl::init(false));
42+
3643
static llvm::cl::opt<bool> clEnableFusePaddingIntoLinalgConsumerOps(
3744
"iree-dispatch-creation-enable-fuse-padding-into-linalg-consumer-ops",
3845
llvm::cl::desc("Enable fusing tensor.pad ops into Linalg consumer ops."),
@@ -120,7 +127,9 @@ static void addDispatchRegionCreationPreprocessingPasses(
120127
.addPass([]() {
121128
return DispatchCreation::createElementwiseOpFusionPass(
122129
ElementwiseOpFusionPassOptions{
123-
clEnableElementWiseFuseMultiReduction});
130+
/*intraDispatch=*/false,
131+
/*fuseMultiReduction=*/clEnableElementWiseFuseMultiReduction,
132+
/*fuseTruncateOps=*/clEnableEarlyTruncFusion});
124133
})
125134
.addPass(IREE::Flow::createCanonicalizePass)
126135
.addPass(mlir::createCSEPass)
@@ -137,7 +146,9 @@ static void addDispatchRegionCreationPreprocessingPasses(
137146
.addPass([]() {
138147
return DispatchCreation::createElementwiseOpFusionPass(
139148
ElementwiseOpFusionPassOptions{
140-
clEnableElementWiseFuseMultiReduction});
149+
/*intraDispatch=*/false,
150+
/*fuseMultiReduction=*/clEnableElementWiseFuseMultiReduction,
151+
/*fuseTruncateOps=*/clEnableEarlyTruncFusion});
141152
})
142153
.addPass(IREE::Flow::createCanonicalizePass)
143154
.addPass(mlir::createCSEPass)
@@ -217,7 +228,14 @@ addDispatchRegionCreationPasses(OpPassManager &passManager,
217228
clEnableFusePaddingIntoLinalgConsumerOps,
218229
clEnableFusePaddingIntoLinalgProducerOps});
219230
})
220-
// Clone all producers into the dispatch region to prepare for being
231+
// Elementwise fuse operations that are inside a dispatch if possible.
232+
.addPass([&]() {
233+
return DispatchCreation::createElementwiseOpFusionPass(
234+
ElementwiseOpFusionPassOptions{/*intraDispatch=*/true,
235+
/*fuseMultiReduction=*/false,
236+
/*fuseTruncateOps=*/true});
237+
})
238+
// Clone all producers into the dispatch region to prepare for being
221239
// isolated from above. This enables running additional transformations
222240
// afterwards that would need the full dispatch content but don't want to
223241
// handle explicit captures as materialized as dispatch workgroup operands
@@ -358,12 +376,34 @@ void registerDispatchCreationPasses() {
358376
}
359377

360378
void registerDispatchCreationPipelines() {
361-
PassPipelineRegistration<TransformOptions> dispatchCreationPipeline(
362-
"iree-dispatch-creation-pipeline",
363-
"Flag used to run passes that form dispatch regions",
364-
[](OpPassManager &passManager, const TransformOptions &transformOptions) {
365-
buildDispatchCreationPassPipeline(passManager, transformOptions);
366-
});
379+
380+
/// Helper struct when registering pass pipeline options.
381+
struct DispatchCreationPipelineOptions
382+
: public PassPipelineOptions<DispatchCreationPipelineOptions> {
383+
Option<bool> aggressiveFusion{
384+
*this,
385+
"aggressive-fusion",
386+
llvm::cl::desc(
387+
"Enable aggressive fusion for dispatch creation pipeline"),
388+
llvm::cl::init(false),
389+
};
390+
391+
TransformOptions toTransformOptions() const {
392+
DispatchCreationOptions options;
393+
options.enableAggressiveFusion = aggressiveFusion;
394+
return TransformOptions{.options = options};
395+
}
396+
};
397+
398+
PassPipelineRegistration<DispatchCreationPipelineOptions>
399+
dispatchCreationPipeline(
400+
"iree-dispatch-creation-pipeline",
401+
"Flag used to run passes that form dispatch regions",
402+
[](OpPassManager &passManager,
403+
const DispatchCreationPipelineOptions &options) {
404+
buildDispatchCreationPassPipeline(passManager,
405+
options.toTransformOptions());
406+
});
367407

368408
PassPipelineRegistration<TransformOptions>
369409
dispatchCreationPreprocessingPipeline(

compiler/src/iree/compiler/DispatchCreation/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ def ElementwiseOpFusionPass :
5050
Pass<"iree-dispatch-creation-elementwise-op-fusion", ""> {
5151
let summary = "Fuse elementwise operations.";
5252
let options = [
53+
Option<"intraDispatch", "intra-dispatch", "bool",
54+
/*default=*/"false", "Fuse operations within a dispatch only (default is to fuse only operations outside of a dispatch)">,
5355
Option<"fuseMultiReduction", "fuse-multi-reduction", "bool",
54-
/*default=*/"true", "Fuse ops that have multiple reduction iterators">
56+
/*default=*/"true", "Fuse ops that have multiple reduction iterators">,
57+
Option<"fuseTruncateOps", "fuse-truncate-ops", "bool",
58+
/*default=*/"false", "Fuse producer truncate-like operations with consumers">,
5559
];
5660
let dependentDialects = [
5761
"mlir::affine::AffineDialect",

compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ iree_lit_test_suite(
4646
"pad_fusion_with_producer.mlir",
4747
"pipeline_tests.mlir",
4848
"propagate_encodings.mlir",
49+
"pipeline_tests_aggressive.mlir",
4950
"set_encoding.mlir",
5051
"set_encoding_padding.mlir",
5152
"set_encoding_pipeline.mlir",

compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ iree_lit_test_suite(
4343
"pad_fusion_with_consumer.mlir"
4444
"pad_fusion_with_producer.mlir"
4545
"pipeline_tests.mlir"
46+
"pipeline_tests_aggressive.mlir"
4647
"propagate_encodings.mlir"
4748
"set_encoding.mlir"
4849
"set_encoding_padding.mlir"

compiler/src/iree/compiler/DispatchCreation/test/dispatch_linalg_on_tensors_default.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ util.func @mixed_conv(%arg0 : tensor<2x130x130x16xf16>, %arg1 : tensor<3x3x16x32
113113
util.return %truncf : tensor<2x128x128x320xf16>
114114
}
115115
// CHECK-LABEL: func public @mixed_conv(
116-
// CHECK: flow.dispatch.workgroups
116+
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.workgroups
117117
// CHECK: %[[FILL:.+]] = linalg.fill
118118
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf
119119
// CHECK-SAME: outs(%[[FILL]] :

0 commit comments

Comments
 (0)