Skip to content

Commit 1ba85a9

Browse files
Move Decompose concat to Dispatch creation pipeline.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent bf711a1 commit 1ba85a9

File tree

17 files changed

+54
-35
lines changed

17 files changed

+54
-35
lines changed

β€Žcompiler/src/iree/compiler/DispatchCreation/BUILD.bazelβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ iree_compiler_cc_library(
2222
"CollapseReductionDimensions.cpp",
2323
"ConvertDispatchRegionsToWorkgroups.cpp",
2424
"ConvertTensorToFlow.cpp",
25+
"DecomposeConcat.cpp",
2526
"DispatchWithTransformDialect.cpp",
2627
"ElementwiseOpFusion.cpp",
2728
"FoldUnitExtentDims.cpp",

β€Žcompiler/src/iree/compiler/DispatchCreation/CMakeLists.txtβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ iree_cc_library(
2424
"CollapseReductionDimensions.cpp"
2525
"ConvertDispatchRegionsToWorkgroups.cpp"
2626
"ConvertTensorToFlow.cpp"
27+
"DecomposeConcat.cpp"
2728
"DispatchWithTransformDialect.cpp"
2829
"ElementwiseOpFusion.cpp"
2930
"FoldUnitExtentDims.cpp"
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
#include "iree/compiler/GlobalOptimization/Passes.h"
7+
#include "iree/compiler/DispatchCreation/Passes.h"
88
#include "llvm/ADT/STLExtras.h"
99
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1010
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -13,10 +13,10 @@
1313
#include "mlir/IR/PatternMatch.h"
1414
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1515

16-
namespace mlir::iree_compiler::GlobalOptimization {
16+
namespace mlir::iree_compiler::DispatchCreation {
1717

1818
#define GEN_PASS_DEF_DECOMPOSECONCATPASS
19-
#include "iree/compiler/GlobalOptimization/Passes.h.inc"
19+
#include "iree/compiler/DispatchCreation/Passes.h.inc"
2020

2121
namespace {
2222

β€Žcompiler/src/iree/compiler/DispatchCreation/Passes.cppβ€Ž

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,20 @@ void addDispatchRegionCreationPreprocessingPasses(OpPassManager &passManager) {
204204

205205
// Pipeline to first create `flow.dispatch.region` ops and then lower to
206206
// `flow.dispatch.workgroup` ops.
207-
static void addDispatchRegionCreationPasses(OpPassManager &passManager) {
207+
static void
208+
addDispatchRegionCreationPasses(OpPassManager &passManager,
209+
const TransformOptions &transformOptions) {
208210
FunctionLikeNest(passManager)
209-
// Only want use the transform dialect for some dispatch regions and let
211+
// We decompose and transpose concatenations immediately before folding
212+
// unit extent dims because this allows decoupling unit dims in the
213+
// concatenation from the transposes that are introduced.
214+
.addPass([&]() {
215+
DecomposeConcatPassOptions options;
216+
options.enableConcatTransposition =
217+
transformOptions.options.outerDimConcat;
218+
return createDecomposeConcatPass(options);
219+
}) // Only want use the transform dialect for some dispatch regions and
220+
// let
210221
// the FormDispatchRegions handle the rest. This only moves the root
211222
// compute op into the dispatch region, so that we can run additional
212223
// transformations afterwards with a simple region and without bothering
@@ -303,7 +314,7 @@ void buildDispatchCreationPassPipeline(
303314
.addPass(mlir::createCSEPass);
304315

305316
addDispatchRegionCreationPreprocessingPasses(passManager);
306-
addDispatchRegionCreationPasses(passManager);
317+
addDispatchRegionCreationPasses(passManager, transformOptions);
307318

308319
FunctionLikeNest(passManager)
309320
.addPass(DispatchCreation::createConvertDispatchRegionsToWorkgroupsPass)

β€Žcompiler/src/iree/compiler/DispatchCreation/Passes.hβ€Ž

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ namespace mlir::iree_compiler::DispatchCreation {
2323

2424
/// This is a placeholder for future. We should pass all the options through the
2525
/// struct.
26-
struct TransformOptions : public PassPipelineOptions<TransformOptions> {};
26+
struct TransformOptions : public PassPipelineOptions<TransformOptions> {
27+
DispatchCreationOptions options;
28+
};
2729

2830
void buildDispatchCreationPassPipeline(
2931
OpPassManager &passManager, const TransformOptions &transformOptions);

β€Žcompiler/src/iree/compiler/DispatchCreation/Passes.tdβ€Ž

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ def CollapseReductionDimensionsPass :
3939
];
4040
}
4141

42+
def DecomposeConcatPass :
43+
Pass<"iree-dispatch-creation-decompose-concat", ""> {
44+
let summary = "Decomposes concatenations into a destination and a sequence of slice inserts.";
45+
let options = [
46+
Option<"enableConcatTransposition", "enable-concat-transposition", "bool",
47+
/*default=*/"false", "Allows transposing concatenations such that "
48+
"they occur on the inner most dims.">,
49+
];
50+
}
51+
52+
4253
def ElementwiseOpFusionPass :
4354
Pass<"iree-dispatch-creation-elementwise-op-fusion", ""> {
4455
let summary = "Fuse elementwise operations.";

β€Žcompiler/src/iree/compiler/DispatchCreation/test/BUILD.bazelβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ iree_lit_test_suite(
4646
"split_reduction.mlir",
4747
"tensor_pad_to_tensor_insert_slice.mlir",
4848
"transform_dispatch_region_formation.mlir",
49+
"transpose_and_decompose_concat.mlir",
4950
"transpose_generic_ops.mlir",
5051
],
5152
include = ["*.mlir"],

β€Žcompiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txtβ€Ž

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ iree_lit_test_suite(
4444
"split_reduction.mlir"
4545
"tensor_pad_to_tensor_insert_slice.mlir"
4646
"transform_dispatch_region_formation.mlir"
47+
"transpose_and_decompose_concat.mlir"
4748
"transpose_generic_ops.mlir"
4849
TOOLS
4950
FileCheck

β€Žcompiler/src/iree/compiler/GlobalOptimization/BUILD.bazelβ€Ž

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ iree_compiler_cc_library(
4747
"CleanupNumericNarrowing.cpp",
4848
"Convert1X1FilterConv2DToMatmul.cpp",
4949
"DataLayoutPropagation.cpp",
50-
"DecomposeConcat.cpp",
5150
"DemoteContractionInputsToBF16.cpp",
5251
"DetachElementwiseFromNamedOps.cpp",
5352
"EraseUnusedLinalgOperands.cpp",

0 commit comments

Comments
Β (0)