Skip to content

Commit c2f1fb9

Browse files
authored
[WIP] Expose multi-use fusion flag to pipeline options. (#21400)
Adds flag (default to enable) that can be used to optionally disable multi-use fusion. Fixes failed scatter fusion in llama 405b. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 9e8691f commit c2f1fb9

File tree

3 files changed

+23
-15
lines changed

3 files changed

+23
-15
lines changed

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,6 @@ static llvm::cl::opt<bool> clEnableFuseHorizontalContractions(
5454
"Enables horizontal fusion of contractions with one common operand"),
5555
llvm::cl::init(false));
5656

57-
static llvm::cl::opt<bool>
58-
clEnableFuseMultiUse("iree-dispatch-creation-fuse-multi-use",
59-
llvm::cl::desc("Fuse multi-use ops."),
60-
llvm::cl::init(false));
61-
6257
static llvm::cl::opt<bool> clEnableDataTiling(
6358
"iree-dispatch-creation-experimental-data-tiling",
6459
llvm::cl::desc("Enable data-tiling at flow level, i.e., it sets encodings "
@@ -116,7 +111,9 @@ static void addCleanupPatterns(OpPassManager &passManager) {
116111
// Pipelines
117112
//===----------------------------------------------------------------------===//
118113

119-
void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
114+
static void addDispatchRegionCreationPreprocessingPasses(
115+
OpPassManager &passManager,
116+
const DispatchCreationOptions &dispatchOptions) {
120117
// 1. Do some simple elementwise op fusion. This could be skipped,
121118
// but could reduce the surface area of ops to handle later.
122119
FunctionLikeNest(passManager)
@@ -161,7 +158,11 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
161158
FunctionLikeNest(passManager)
162159
// 5. After all the reshape propagations, fuse elementwise operations
163160
// even if the producer has multiple uses.
164-
.addPass(DispatchCreation::createFuseMultiUseElementwiseProducerPass)
161+
.addPredicatedPass(dispatchOptions.enableFuseMultiUse,
162+
[&]() {
163+
return DispatchCreation::
164+
createFuseMultiUseElementwiseProducerPass();
165+
})
165166

166167
// 6. Some more "post elementwise fusion passes".
167168
// a. Detensorize.
@@ -312,7 +313,8 @@ void buildDispatchCreationPassPipeline(
312313
.addPass(IREE::Flow::createCanonicalizePass)
313314
.addPass(mlir::createCSEPass);
314315

315-
addDispatchRegionCreationPreprocessingPasses(passManager);
316+
addDispatchRegionCreationPreprocessingPasses(passManager,
317+
transformOptions.options);
316318
addDispatchRegionCreationPasses(passManager, transformOptions.options);
317319

318320
FunctionLikeNest(passManager)
@@ -363,13 +365,16 @@ void registerDispatchCreationPipelines() {
363365
buildDispatchCreationPassPipeline(passManager, transformOptions);
364366
});
365367

366-
PassPipelineRegistration<> dispatchCreationPreprocessingPipeline(
367-
"iree-dispatch-creation-preprocessing-pipeline",
368-
"Flag used to run preprocessing passes that run passes before dispatch "
369-
"region formation. Used only for testing",
370-
[](OpPassManager &passManager) {
371-
addDispatchRegionCreationPreprocessingPasses(passManager);
372-
});
368+
PassPipelineRegistration<TransformOptions>
369+
dispatchCreationPreprocessingPipeline(
370+
"iree-dispatch-creation-preprocessing-pipeline",
371+
"Flag used to run preprocessing passes that run passes before "
372+
"dispatch region formation. Used only for testing",
373+
[](OpPassManager &passManager,
374+
const TransformOptions &transformOptions) {
375+
addDispatchRegionCreationPreprocessingPasses(
376+
passManager, transformOptions.options);
377+
});
373378
}
374379

375380
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/Pipelines/Options.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,8 @@ void DispatchCreationOptions::bindOptions(OptionsBinder &binder) {
310310
init_at_opt(llvm::OptimizationLevel::O2, true)},
311311
llvm::cl::desc("Aggressive fusion opportunities that are behind a flag "
312312
"since all backends dont support it yet"));
313+
binder.opt<bool>("iree-dispatch-creation-fuse-multi-use", enableFuseMultiUse,
314+
llvm::cl::desc("Fuse operations with multiple uses."));
313315
}
314316

315317
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Pipelines/Options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ struct DispatchCreationOptions {
218218
llvm::OptimizationLevel optLevel;
219219

220220
bool enableAggressiveFusion = false;
221+
bool enableFuseMultiUse = true;
221222

222223
void bindOptions(OptionsBinder &binder);
223224
using FromFlags = OptionsFromFlags<DispatchCreationOptions>;

0 commit comments

Comments
 (0)