Skip to content

Commit 5a4632f

Browse files
authored
Fix Dispatch Creation TransformOptions (iree-org#21964)
Moves all the members from `DispatchCreationPipelineOptions` into `TransformOptions` and removes `DispatchCreationOptions` member, similar to what is done in Stream. `TransformOptions` can now be used to directly parse pipeline flags into the struct. Signed-off-by: Ian Wood <[email protected]>
1 parent be52c3d commit 5a4632f

File tree

3 files changed

+36
-49
lines changed

3 files changed

+36
-49
lines changed

compiler/src/iree/compiler/DispatchCreation/Passes.cpp

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ static void addCleanupPatterns(OpPassManager &passManager) {
109109
//===----------------------------------------------------------------------===//
110110

111111
static void addDispatchRegionCreationPreprocessingPasses(
112-
OpPassManager &passManager,
113-
const DispatchCreationOptions &dispatchOptions) {
112+
OpPassManager &passManager, const TransformOptions &dispatchOptions) {
114113
// 1. Do some simple elementwise op fusion. This could be skipped,
115114
// but could reduce the surface area of ops to handle later.
116115
FunctionLikeNest(passManager)
@@ -203,9 +202,8 @@ static void addDispatchRegionCreationPreprocessingPasses(
203202
// Note that we should not hoist out small constants before the dispatch regions
204203
// are converted to workgroups. E.g., the `cseConstant` option needs to be false
205204
// in greedy pattern rewriting drivers.
206-
static void
207-
addDispatchRegionCreationPasses(OpPassManager &passManager,
208-
const DispatchCreationOptions &options) {
205+
static void addDispatchRegionCreationPasses(OpPassManager &passManager,
206+
const TransformOptions &options) {
209207
FunctionLikeNest(passManager)
210208
// Create dispatches for scalar operations as roots.
211209
.addPass(DispatchCreation::createFormScalarDispatchesPass)
@@ -331,9 +329,8 @@ void buildDispatchCreationPassPipeline(
331329
.addPass(IREE::Flow::createCanonicalizePass)
332330
.addPass(mlir::createCSEPass);
333331

334-
addDispatchRegionCreationPreprocessingPasses(passManager,
335-
transformOptions.options);
336-
addDispatchRegionCreationPasses(passManager, transformOptions.options);
332+
addDispatchRegionCreationPreprocessingPasses(passManager, transformOptions);
333+
addDispatchRegionCreationPasses(passManager, transformOptions);
337334

338335
FunctionLikeNest(passManager)
339336
.addPass(DispatchCreation::createConvertDispatchRegionsToWorkgroupsPass)
@@ -377,40 +374,12 @@ void registerDispatchCreationPasses() {
377374

378375
void registerDispatchCreationPipelines() {
379376

380-
/// Helper struct when registering pass pipeline options.
381-
struct DispatchCreationPipelineOptions
382-
: public PassPipelineOptions<DispatchCreationPipelineOptions> {
383-
Option<bool> aggressiveFusion{
384-
*this,
385-
"aggressive-fusion",
386-
llvm::cl::desc(
387-
"Enable aggressive fusion for dispatch creation pipeline"),
388-
llvm::cl::init(false),
389-
};
390-
Option<bool> dataTiling{
391-
*this,
392-
"data-tiling",
393-
llvm::cl::desc("Enable data-tiling for dispatch creation pipeline"),
394-
llvm::cl::init(false),
395-
};
396-
397-
std::unique_ptr<TransformOptions> toTransformOptions() const {
398-
auto options = std::make_unique<TransformOptions>();
399-
options->options.enableAggressiveFusion = aggressiveFusion;
400-
options->options.dataTiling = dataTiling;
401-
return options;
402-
}
403-
};
404-
405-
PassPipelineRegistration<DispatchCreationPipelineOptions>
406-
dispatchCreationPipeline(
407-
"iree-dispatch-creation-pipeline",
408-
"Flag used to run passes that form dispatch regions",
409-
[](OpPassManager &passManager,
410-
const DispatchCreationPipelineOptions &options) {
411-
buildDispatchCreationPassPipeline(passManager,
412-
*(options.toTransformOptions()));
413-
});
377+
PassPipelineRegistration<TransformOptions> dispatchCreationPipeline(
378+
"iree-dispatch-creation-pipeline",
379+
"Flag used to run passes that form dispatch regions",
380+
[](OpPassManager &passManager, const TransformOptions &options) {
381+
buildDispatchCreationPassPipeline(passManager, options);
382+
});
414383

415384
PassPipelineRegistration<TransformOptions>
416385
dispatchCreationPreprocessingPipeline(
@@ -419,8 +388,8 @@ void registerDispatchCreationPipelines() {
419388
"dispatch region formation. Used only for testing",
420389
[](OpPassManager &passManager,
421390
const TransformOptions &transformOptions) {
422-
addDispatchRegionCreationPreprocessingPasses(
423-
passManager, transformOptions.options);
391+
addDispatchRegionCreationPreprocessingPasses(passManager,
392+
transformOptions);
424393
});
425394
}
426395

compiler/src/iree/compiler/DispatchCreation/Passes.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <functional>
1111

1212
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtDialect.h"
13-
#include "iree/compiler/Pipelines/Options.h"
1413
#include "mlir/IR/BuiltinOps.h"
1514
#include "mlir/Interfaces/FunctionInterfaces.h"
1615
#include "mlir/Pass/Pass.h"
@@ -24,10 +23,25 @@ enum class EncodingOptions { Padding, MatmulK, Generic };
2423
// Pipelines
2524
//===----------------------------------------------------------------------===//
2625

27-
/// This is a placeholder for future. We should pass all the options through the
28-
/// struct.
2926
struct TransformOptions : public PassPipelineOptions<TransformOptions> {
30-
DispatchCreationOptions options;
27+
Option<bool> enableAggressiveFusion{
28+
*this,
29+
"aggressive-fusion",
30+
llvm::cl::desc("Enable aggressive fusion for dispatch creation pipeline"),
31+
llvm::cl::init(false),
32+
};
33+
Option<bool> enableFuseMultiUse{
34+
*this,
35+
"fuse-multi-use",
36+
llvm::cl::desc("Fuse operations with multiple uses."),
37+
llvm::cl::init(true),
38+
};
39+
Option<bool> dataTiling{
40+
*this,
41+
"data-tiling",
42+
llvm::cl::desc("Enable data-tiling for dispatch creation pipeline"),
43+
llvm::cl::init(false),
44+
};
3145
};
3246

3347
void buildDispatchCreationPassPipeline(

compiler/src/iree/compiler/Pipelines/Pipelines.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,11 @@ void buildIREEVMTransformPassPipeline(
277277
break;
278278
default:
279279
DispatchCreation::TransformOptions dispatchTransformOptions;
280-
dispatchTransformOptions.options = dispatchCreationOptions;
280+
dispatchTransformOptions.enableAggressiveFusion =
281+
dispatchCreationOptions.enableAggressiveFusion;
282+
dispatchTransformOptions.enableFuseMultiUse =
283+
dispatchCreationOptions.enableFuseMultiUse;
284+
dispatchTransformOptions.dataTiling = dispatchCreationOptions.dataTiling;
281285
if (compileFrom < IREEVMPipelinePhase::DispatchCreation) { // late-entry
282286
IREE_TRACE_ADD_BEGIN_FRAME_PASS(passManager, "DispatchCreation");
283287
if (hooks.beforePhase)

0 commit comments

Comments
 (0)