[mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations #125401
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/125401.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10eff..4b325aaeab87ca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
+/// Add patterns to fold linalg.fill into linalg.reduce by creating a fused
+/// linalg.generic operation.
+/// The fill operation is expected to happen only on the first index
+/// of the reduction dimension. Currently only one reduction dimension is
+/// supported. Given the pattern:
+/// %empty = tensor.empty() : tensor<i8>
+/// %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>)
+///     -> tensor<i8>
+/// %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///     outs(%filled : tensor<i8>) dimensions = [0]
+/// (%in: i8, %init: i8) {
+/// %3 = arith.addi %in, %init : i8
+/// linalg.yield %3 : i8
+/// }
+/// The pattern is rewritten into:
+/// %empty = tensor.empty() : tensor<i8>
+/// %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///     outs(%empty : tensor<i8>) {
+/// ^bb0(%in: i8, %init: i8):
+/// %c0 = arith.constant 0 : index
+/// %zero = arith.constant 0 : i8
+/// %index = linalg.index 0 : index
+/// %cmp = arith.cmpi eq, %index, %c0 : index
+/// %sum = arith.select %cmp, %zero, %init : i8
+/// %res = arith.addi %in, %sum : i8
+/// linalg.yield %res : i8
+/// }
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
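The equivalence this rewrite relies on can be sanity-checked outside MLIR. The sketch below is a hypothetical Python model (the function names are invented and nothing here is part of the patch) of both forms of a 1-D reduction, assuming the reduction dimension is visited in increasing order:

```python
def fill_then_reduce(xs, fill_value):
    # Models linalg.fill followed by linalg.reduce over dimension 0.
    acc = fill_value
    for x in xs:
        acc = acc + x
    return acc

def fused_generic(xs, fill_value):
    # Models the fused form: on the first index of the reduction
    # dimension, the running value is replaced by the fill value
    # (arith.select on `linalg.index == 0`), so no separate fill runs.
    acc = None  # contents of the un-filled tensor.empty are undefined
    for index, x in enumerate(xs):
        sum_ = fill_value if index == 0 else acc
        acc = x + sum_
    return acc

data = [3, 1, 4, 1, 5, 9, 2, 6]
assert fill_then_reduce(data, 0) == fused_generic(data, 0) == sum(data)
```

Note the in-order assumption baked into `fused_generic`: the two forms only agree when index 0 of the reduction dimension is visited first, which is exactly the point the reviewers raise in the conversation.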
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..cace3dcb6cbfca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
EraseUnusedOperandsAndResults.cpp
FoldAddIntoDest.cpp
FusePadOpWithLinalgProducer.cpp
+ FuseFillOpWithReduceOp.cpp
Fusion.cpp
Generalization.cpp
Hoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181a..2c2cef60428743 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgElementwiseFusion.cpp
TestLinalgFusionTransforms.cpp
TestLinalgRankReduceContractionOps.cpp
+ TestLinalgFuseFillOpWithReduceOp.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae8..7e92095ff2fae7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
void registerTestLinalgRankReduceContractionOps();
+void registerTestLinalgFuseFillOpWithReduceOp();
void registerTestLinalgTransforms();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgRankReduceContractionOps();
+ mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
mlir::test::registerTestLinalgTransforms();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
Force-pushed 5ebcb8c to 56afd0f.
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Add patterns to fuse a linalg fill operation with a linalg operation.
In general this is not the preferred way of fusing a fill with a reduction. The preferred way is to use a tile + fuse approach to fuse at tile granularity (since fusing a fill with its consumer reduction operation results in an inherently imperfectly nested loop computation). The main issue here is that this adds a conditional to the innermost loop computation, which isn't generally performant.
But this seems still valid. Could you add some comments explaining the alternatives?
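The performance concern can be illustrated with a loop-level sketch. This is hypothetical Python, written only to show the shape of the two loop nests, not anything from the patch:

```python
def fused_with_conditional(matrix):
    # Shape produced by this patch: the fill is folded into the
    # reduction body, so every innermost iteration pays for a branch.
    results = []
    for row in matrix:
        acc = None
        for j, x in enumerate(row):
            acc = x + (0 if j == 0 else acc)  # branch on every element
        results.append(acc)
    return results

def tile_and_fuse(matrix):
    # Shape the reviewer prefers: fuse at tile granularity, so the
    # fill stays outside the innermost (reduction) loop.
    results = []
    for row in matrix:
        acc = 0  # fill once per tile, no per-element branch
        for x in row:
            acc += x
        results.append(acc)
    return results

m = [[1, 2, 3], [4, 5, 6]]
assert fused_with_conditional(m) == tile_and_fuse(m) == [6, 15]
```

Both produce the same values; the difference is purely in where the initialization work (and the conditional) lands in the loop nest.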
MaheshRavishankar left a comment:
Marking "request changes" for the documentation request for now. Will review the change itself in a bit.
Hey @MaheshRavishankar, I can lower the linalg.generic of fill and reduce into loops and do it then, but at that point the flow is much more complicated for identifying the fill + reduce pattern. I do have some patterns to optimize the non-nested loops after the lowering, but it is more HW specific. I think some other people might use this transformation pattern as I do.
/// %index = linalg.index %c0 : index
/// %cmp = arith.cmpi eq, %cst, %index : i1
The iteration order of reduction and parallel iterators in a linalg operation is undefined. I don't think you can assume that the first iteration of the iterator is 0.
To be clear, I don't think you can do this fusion while adding linalg.index to the body, because that would mean you are assuming the first iteration index to be something.
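This objection can be demonstrated concretely. In the illustrative Python sketch below (names invented; a linalg.generic reduction is free to pick any iteration order), running the same select-based body over the data in reverse order no longer computes the sum:

```python
def fused_body_in_order(xs):
    # The fused body, iterating the reduction dimension low-to-high:
    # the "index == 0" reset fires first, so the sum comes out right.
    acc = 0  # stand-in for the undefined contents of tensor.empty
    for index, x in enumerate(xs):
        acc = x + (0 if index == 0 else acc)
    return acc

def fused_body_reversed(xs):
    # Same body, iterating high-to-low. The "index == 0" reset now
    # fires on the *last* step and discards everything accumulated
    # so far: only xs[0] (plus the fill value) survives.
    acc = 0  # stand-in for the undefined contents of tensor.empty
    for index, x in reversed(list(enumerate(xs))):
        acc = x + (0 if index == 0 else acc)
    return acc

data = [3, 1, 4, 1]
assert fused_body_in_order(data) == sum(data)   # 9
assert fused_body_reversed(data) == data[0]     # 3, not the sum
```

The plain fill + reduce pair has no such order sensitivity, which is why the reviewers consider folding the fill via `linalg.index` incorrect at the linalg.generic level.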
I don't think doing this at the linalg.generic level is correct, because you cannot assume a single iteration order. IIUC, what you are trying to do here is tile the following example on the parallel and reduction dims:

%empty = linalg.fill
linalg.generic ins(...) outs(%empty) { iterator_types = [parallel, reduction] }

which after tiling gives:
scf.for %i = 0 to ... {
%empty = linalg.fill
scf.for %j = 0 to ... init_args(%arg0 = %empty) {
%out = linalg.generic ins(...) outs(%arg0)
yield %out
}
}
whereas with the fill folded into the reduction loop, you would get:
%em = tensor.empty()
scf.for %i = 0 to ... {
scf.for %j = 0 to ... init_args(%arg0 = %em) {
%filled = scf.if (%j == 0) init_args(%arg1 = %arg0) {
%fill = linalg.fill outs(%arg1)
yield %fill
} else {
yield %arg1
}
linalg.generic ins(...) outs(%filled)
}
}
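The second loop nest above can be mirrored in a small Python simulation (illustrative only, with made-up sizes) to check that conditionally filling when `%j == 0` inside the loop nest yields the same result as filling up front:

```python
def tiled_reduction(data, num_i, num_j):
    # Mirrors the second scf.for nest: the output is filled only on
    # the first reduction step (j == 0), inside the loop nest, as the
    # scf.if does around linalg.fill.
    out = [None] * num_i  # tensor.empty: undefined contents
    for i in range(num_i):
        for j in range(num_j):
            if j == 0:
                out[i] = 0          # scf.if (%j == 0) { linalg.fill }
            out[i] += data[i][j]    # accumulating linalg.generic tile
    return out

data = [[1, 2, 3, 4], [5, 6, 7, 8]]
assert tiled_reduction(data, 2, 4) == [10, 26]
```

Here the guard is on the loop induction variable `j`, which the loops do define an order for, rather than on `linalg.index` inside a generic body, where no order is guaranteed.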
Thanks @Groverkss. That is correct. For a moment there I forgot this and went down a rabbit hole of "why do we not do this again". Btw, this example is kind of interesting... you are initializing the sum on every [...]. That is this sequence: [...]. So to get the final loop sequence, you [...].
Now bufferization + lowering to loops will give you what you expect.
@MaheshRavishankar @Groverkss
Unfortunately, this flow makes the final linalg.generic too naive (it works on one element), and our general flow depends on the fact that [...]. We do use the FuseAndTile pattern, and we do co-tile for the [...].
This was an example. You can use any tile size you want. I wrote that just to show the loop structure that it would generate with tile and fuse.