Skip to content

Commit 4451b8b

Browse files
authored
Add pipeline options for Dispatch Creation (#20217)
This change allows specifying optimization level defaults for flags used in dispatch creation. So far, I have only moved `iree-dispatch-creation-enable-aggressive-fusion` and didn't change the name as it would be a breaking change for a lot of users. Also, this cleans up no longer needed sdxl benchmark flags. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent b36b355 commit 4451b8b

File tree

14 files changed

+65
-44
lines changed

14 files changed

+65
-44
lines changed

compiler/src/iree/compiler/API/Internal/CompilerDriver.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ struct GlobalInit {
241241
InputDialectOptions *clInputOptions = nullptr;
242242
PreprocessingOptions *clPreprocessingOptions = nullptr;
243243
GlobalOptimizationOptions *clGlobalOptimizationOptions = nullptr;
244+
DispatchCreationOptions *clDispatchCreationOptions = nullptr;
244245
SchedulingOptions *clSchedulingOptions = nullptr;
245246
IREE::HAL::TargetOptions *clHalTargetOptions = nullptr;
246247
IREE::VM::TargetOptions *clVmTargetOptions = nullptr;
@@ -286,6 +287,7 @@ void GlobalInit::registerCommandLineOptions() {
286287
clInputOptions = &InputDialectOptions::FromFlags::get();
287288
clPreprocessingOptions = &PreprocessingOptions::FromFlags::get();
288289
clGlobalOptimizationOptions = &GlobalOptimizationOptions::FromFlags::get();
290+
clDispatchCreationOptions = &DispatchCreationOptions::FromFlags::get();
289291
clSchedulingOptions = &SchedulingOptions::FromFlags::get();
290292
clHalTargetOptions = &IREE::HAL::TargetOptions::FromFlags::get();
291293
clVmTargetOptions = &IREE::VM::TargetOptions::FromFlags::get();
@@ -396,6 +398,7 @@ struct Session {
396398
InputDialectOptions inputOptions;
397399
PreprocessingOptions preprocessingOptions;
398400
GlobalOptimizationOptions highLevelOptimizationOptions;
401+
DispatchCreationOptions dispatchCreationOptions;
399402
SchedulingOptions schedulingOptions;
400403
IREE::HAL::TargetOptions halTargetOptions;
401404
IREE::VM::TargetOptions vmTargetOptions;
@@ -423,6 +426,7 @@ Session::Session(GlobalInit &globalInit)
423426
inputOptions = *globalInit.clInputOptions;
424427
preprocessingOptions = *globalInit.clPreprocessingOptions;
425428
highLevelOptimizationOptions = *globalInit.clGlobalOptimizationOptions;
429+
dispatchCreationOptions = *globalInit.clDispatchCreationOptions;
426430
schedulingOptions = *globalInit.clSchedulingOptions;
427431
halTargetOptions = *globalInit.clHalTargetOptions;
428432
vmTargetOptions = *globalInit.clVmTargetOptions;
@@ -443,6 +447,7 @@ Session::Session(GlobalInit &globalInit)
443447
preprocessingOptions.bindOptions(binder);
444448
inputOptions.bindOptions(binder);
445449
highLevelOptimizationOptions.bindOptions(binder);
450+
dispatchCreationOptions.bindOptions(binder);
446451
schedulingOptions.bindOptions(binder);
447452
halTargetOptions.bindOptions(binder);
448453
vmTargetOptions.bindOptions(binder);
@@ -982,9 +987,9 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) {
982987
buildIREEVMTransformPassPipeline(
983988
session.targetRegistry, session.bindingOptions, session.inputOptions,
984989
session.preprocessingOptions, session.highLevelOptimizationOptions,
985-
session.schedulingOptions, session.halTargetOptions,
986-
session.vmTargetOptions, pipelineHooks, *passManager, compileFrom,
987-
compileTo);
990+
session.dispatchCreationOptions, session.schedulingOptions,
991+
session.halTargetOptions, session.vmTargetOptions, pipelineHooks,
992+
*passManager, compileFrom, compileTo);
988993
break;
989994
}
990995
case IREE_COMPILER_PIPELINE_HAL_EXECUTABLE: {
@@ -1015,8 +1020,9 @@ bool Invocation::runPipeline(enum iree_compiler_pipeline_t pipeline) {
10151020
buildIREEPrecompileTransformPassPipeline(
10161021
session.targetRegistry, session.bindingOptions, session.inputOptions,
10171022
session.preprocessingOptions, session.highLevelOptimizationOptions,
1018-
session.schedulingOptions, session.halTargetOptions, pipelineHooks,
1019-
*passManager, compileFrom, compileTo);
1023+
session.dispatchCreationOptions, session.schedulingOptions,
1024+
session.halTargetOptions, pipelineHooks, *passManager, compileFrom,
1025+
compileTo);
10201026
break;
10211027
}
10221028
default:

compiler/src/iree/compiler/ConstEval/JitGlobals.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ struct CompileOptions {
7272
InputDialectOptions inputOptions;
7373
PreprocessingOptions preprocessingOptions;
7474
GlobalOptimizationOptions globalOptimizationOptions;
75+
DispatchCreationOptions dispatchCreationOptions;
7576
SchedulingOptions schedulingOptions;
7677
IREE::HAL::TargetOptions executableOptions;
7778
IREE::VM::TargetOptions targetOptions;
@@ -647,6 +648,7 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase<JitGlobalsPass> {
647648
*targetRegistry.value, compileOptions->bindingOptions,
648649
compileOptions->inputOptions, compileOptions->preprocessingOptions,
649650
compileOptions->globalOptimizationOptions,
651+
compileOptions->dispatchCreationOptions,
650652
compileOptions->schedulingOptions, compileOptions->executableOptions,
651653
compileOptions->targetOptions, compileOptions->hooks, compilePipeline);
652654
}

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,6 @@ static llvm::cl::opt<bool>
7070
llvm::cl::desc("Fuse multi-use ops."),
7171
llvm::cl::init(false));
7272

73-
static llvm::cl::opt<bool> clEnableAggressiveFusion(
74-
"iree-dispatch-creation-enable-aggressive-fusion",
75-
llvm::cl::desc("Aggressive fusion opportunities that are behind a flag "
76-
"since all backends dont support it yet"),
77-
llvm::cl::init(false));
78-
7973
static llvm::cl::opt<bool> clEnableDataTiling(
8074
"iree-dispatch-creation-experimental-data-tiling",
8175
llvm::cl::desc("Enable data-tiling at flow level, i.e., it sets encodings "
@@ -215,16 +209,18 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
215209
// Note that we should not hoist out small constants before the dispatch regions
216210
// are converted to workgroups. E.g., the `cseConstant` option needs to be false
217211
// in greedy pattern rewriting drivers.
218-
static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
212+
static void
213+
addDispatchRegionCreationPasses(OpPassManager &passManager,
214+
const DispatchCreationOptions &options) {
219215
FunctionLikeNest(passManager)
220216
// Create dispatches for scalar operations as roots.
221217
.addPass(DispatchCreation::createFormScalarDispatchesPass)
222218
// Create `flow.dispatch.region` centered around a root and fuse with
223219
// producers and consumers.
224-
.addPass([] {
220+
.addPass([&] {
225221
return DispatchCreation::createFormDispatchRegionsPass(
226222
FormDispatchRegionsPassOptions{
227-
clEnableAggressiveFusion,
223+
options.enableAggressiveFusion,
228224
clEnableFusePaddingIntoLinalgConsumerOps,
229225
clEnableFusePaddingIntoLinalgProducerOps});
230226
})
@@ -233,10 +229,10 @@ static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
233229
// afterwards that would need the full dispatch content but don't want to
234230
// handle explicit captures as materialized as dispatch workgroup operands
235231
// and block arguments.
236-
.addPass([] {
232+
.addPass([&] {
237233
return DispatchCreation::createCloneProducersIntoDispatchRegionsPass(
238234
CloneProducersIntoDispatchRegionsPassOptions{
239-
clEnableAggressiveFusion});
235+
options.enableAggressiveFusion});
240236
})
241237
// Collapse dimensions of linalg Ops.
242238
.addPass(DispatchCreation::createCollapseDimensionsPass);
@@ -306,7 +302,7 @@ void buildDispatchCreationPassPipeline(
306302
.addPass(mlir::createCSEPass);
307303

308304
addDispatchRegionCreationPreprocessingPasses(passManager);
309-
addDispatchRegionCreationPasses(passManager);
305+
addDispatchRegionCreationPasses(passManager, transformOptions.options);
310306

311307
FunctionLikeNest(passManager)
312308
.addPass(DispatchCreation::createConvertDispatchRegionsToWorkgroupsPass)

compiler/src/iree/compiler/DispatchCreation/Passes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ namespace mlir::iree_compiler::DispatchCreation {
2323

2424
/// This is a placeholder for future. We should pass all the options through the
2525
/// struct.
26-
struct TransformOptions : public PassPipelineOptions<TransformOptions> {};
26+
struct TransformOptions : public PassPipelineOptions<TransformOptions> {
27+
DispatchCreationOptions options;
28+
};
2729

2830
void buildDispatchCreationPassPipeline(
2931
OpPassManager &passManager, const TransformOptions &transformOptions);

compiler/src/iree/compiler/Pipelines/Options.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ IREE_DEFINE_COMPILER_OPTION_FLAGS(
1414
IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::SchedulingOptions);
1515
IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::PreprocessingOptions);
1616
IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::GlobalPipelineOptions);
17+
IREE_DEFINE_COMPILER_OPTION_FLAGS(mlir::iree_compiler::DispatchCreationOptions);
1718

1819
namespace mlir::iree_compiler {
1920

@@ -293,4 +294,19 @@ void SchedulingOptions::bindOptions(OptionsBinder &binder) {
293294
llvm::cl::cat(category));
294295
}
295296

297+
void DispatchCreationOptions::bindOptions(OptionsBinder &binder) {
298+
static llvm::cl::OptionCategory category(
299+
"IREE options for controlling dispatch region creation.");
300+
auto init_at_opt = binder.optimizationLevel(
301+
"iree-dispatch-creation-opt-level", optLevel,
302+
llvm::cl::desc("Optimization level for the this pipeline"),
303+
llvm::cl::cat(category));
304+
binder.opt<bool>(
305+
"iree-dispatch-creation-enable-aggressive-fusion", enableAggressiveFusion,
306+
{init_at_opt(llvm::OptimizationLevel::O0, false),
307+
init_at_opt(llvm::OptimizationLevel::O2, true)},
308+
llvm::cl::desc("Aggressive fusion opportunities that are behind a flag "
309+
"since all backends dont support it yet"));
310+
}
311+
296312
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Pipelines/Options.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,15 @@ struct SchedulingOptions {
214214
using FromFlags = OptionsFromFlags<SchedulingOptions>;
215215
};
216216

217+
struct DispatchCreationOptions {
218+
llvm::OptimizationLevel optLevel;
219+
220+
bool enableAggressiveFusion = false;
221+
222+
void bindOptions(OptionsBinder &binder);
223+
using FromFlags = OptionsFromFlags<DispatchCreationOptions>;
224+
};
225+
217226
} // namespace mlir::iree_compiler
218227

219228
#endif // IREE_COMPILER_PIPELINES_OPTIONS_H_

compiler/src/iree/compiler/Pipelines/Pipelines.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ void buildIREEPrecompileTransformPassPipeline(
7777
BindingOptions bindingOptions, InputDialectOptions inputOptions,
7878
PreprocessingOptions preprocessingOptions,
7979
GlobalOptimizationOptions globalOptimizationOptions,
80+
DispatchCreationOptions dispatchCreationOptions,
8081
SchedulingOptions schedulingOptions,
8182
IREE::HAL::TargetOptions halTargetOptions, IREEVMPipelineHooks &hooks,
8283
OpPassManager &passManager, IREEVMPipelinePhase compileFrom,
@@ -247,15 +248,16 @@ void buildIREEVMTransformPassPipeline(
247248
BindingOptions bindingOptions, InputDialectOptions inputOptions,
248249
PreprocessingOptions preprocessingOptions,
249250
GlobalOptimizationOptions globalOptimizationOptions,
251+
DispatchCreationOptions dispatchCreationOptions,
250252
SchedulingOptions schedulingOptions,
251253
IREE::HAL::TargetOptions halTargetOptions,
252254
IREE::VM::TargetOptions vmTargetOptions, IREEVMPipelineHooks &hooks,
253255
OpPassManager &passManager, IREEVMPipelinePhase compileFrom,
254256
IREEVMPipelinePhase compileTo) {
255257
buildIREEPrecompileTransformPassPipeline(
256258
targetRegistry, bindingOptions, inputOptions, preprocessingOptions,
257-
globalOptimizationOptions, schedulingOptions, halTargetOptions, hooks,
258-
passManager, compileFrom, compileTo);
259+
globalOptimizationOptions, dispatchCreationOptions, schedulingOptions,
260+
halTargetOptions, hooks, passManager, compileFrom, compileTo);
259261

260262
if (compileTo <= IREEVMPipelinePhase::GlobalOptimization)
261263
return; // early-exit
@@ -274,13 +276,14 @@ void buildIREEVMTransformPassPipeline(
274276
// No flow/stream processing (implies no tensors).
275277
break;
276278
default:
277-
DispatchCreation::TransformOptions dispatchCreationOptions;
279+
DispatchCreation::TransformOptions dispatchTransformOptions;
280+
dispatchTransformOptions.options = dispatchCreationOptions;
278281
if (compileFrom < IREEVMPipelinePhase::DispatchCreation) { // late-entry
279282
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "DispatchCreation");
280283
if (hooks.beforePhase)
281284
hooks.beforePhase(IREEVMPipelinePhase::DispatchCreation, passManager);
282285
DispatchCreation::buildDispatchCreationPassPipeline(
283-
passManager, dispatchCreationOptions);
286+
passManager, dispatchTransformOptions);
284287
if (hooks.afterPhase)
285288
hooks.afterPhase(IREEVMPipelinePhase::DispatchCreation, passManager);
286289
IREE_TRACE_ADD_END_FRAME_PASS(passManager, "DispatchCreation");
@@ -385,6 +388,7 @@ void buildDefaultIREEVMTransformPassPipeline(OpPassManager &passManager) {
385388
IREE::HAL::TargetRegistry::getGlobal(), BindingOptions::FromFlags::get(),
386389
InputDialectOptions::FromFlags::get(),
387390
PreprocessingOptions::FromFlags::get(), highLevelOptimizations,
391+
DispatchCreationOptions::FromFlags::get(),
388392
SchedulingOptions::FromFlags::get(),
389393
IREE::HAL::TargetOptions::FromFlags::get(),
390394
IREE::VM::TargetOptions::FromFlags::get(), defaultHooks, passManager);

compiler/src/iree/compiler/Pipelines/Pipelines.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ void buildIREEPrecompileTransformPassPipeline(
104104
BindingOptions bindingOptions, InputDialectOptions inputOptions,
105105
PreprocessingOptions preprocessingOptions,
106106
GlobalOptimizationOptions highLevelOptimizationOptions,
107+
DispatchCreationOptions dispatchCreationOptions,
107108
SchedulingOptions schedulingOptions,
108109
IREE::HAL::TargetOptions halTargetOptions, IREEVMPipelineHooks &hooks,
109110
OpPassManager &passManager,
@@ -120,6 +121,7 @@ void buildIREEVMTransformPassPipeline(
120121
BindingOptions bindingOptions, InputDialectOptions inputOptions,
121122
PreprocessingOptions preprocessingOptions,
122123
GlobalOptimizationOptions highLevelOptimizationOptions,
124+
DispatchCreationOptions dispatchCreationOptions,
123125
SchedulingOptions schedulingOptions,
124126
IREE::HAL::TargetOptions halTargetOptions,
125127
IREE::VM::TargetOptions vmTargetOptions, IREEVMPipelineHooks &hooks,

experimental/regression_suite/shark-test-suite-models/sd3/test_clip.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,11 @@ def SD3_CLIP_COMMON_RUN_FLAGS(
104104
ROCM_COMPILE_FLAGS = [
105105
"--iree-hal-target-backends=rocm",
106106
f"--iree-hip-target={rocm_chip}",
107+
"--iree-opt-level=O3",
107108
"--iree-input-type=torch",
108109
"--iree-opt-const-eval=false",
109-
"--iree-global-opt-propagate-transposes=true",
110-
"--iree-opt-outer-dim-concat=true",
111110
"--iree-hip-waves-per-eu=2",
112111
"--iree-llvmgpu-enable-prefetch",
113-
"--iree-dispatch-creation-enable-aggressive-fusion",
114-
"--iree-opt-aggressively-propagate-transposes=true",
115112
"--iree-codegen-llvmgpu-use-vector-distribution=true",
116113
"--iree-execution-model=async-external",
117114
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline,iree-preprocessing-pad-to-intrinsics{pad-target-type=conv})",

experimental/regression_suite/shark-test-suite-models/sd3/test_mmdit.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,8 @@ def SD3_MMDIT_COMMON_RUN_FLAGS(
8383
ROCM_COMPILE_FLAGS = [
8484
"--iree-hal-target-backends=rocm",
8585
f"--iree-hip-target={rocm_chip}",
86+
"--iree-opt-level=O3",
8687
"--iree-opt-const-eval=false",
87-
"--iree-global-opt-propagate-transposes=true",
88-
"--iree-dispatch-creation-enable-aggressive-fusion=true",
89-
"--iree-opt-aggressively-propagate-transposes=true",
90-
"--iree-opt-outer-dim-concat=true",
9188
"--iree-vm-target-truncate-unsupported-floats",
9289
"--iree-llvmgpu-enable-prefetch=true",
9390
"--iree-opt-data-tiling=false",

0 commit comments

Comments
 (0)