-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations #125401
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); | |
| /// convert to a `linalg.dot`. | ||
| void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); | ||
|
|
||
| /// Add patterns to fuse a linalg.fill with a consuming linalg operation: | ||
| /// fold linalg.fill into linalg.reduce by creating a fused | ||
| /// linalg.generic operation. | ||
| /// The fill operation is expected to happen only on the first index | ||
| /// of the reduction dimension. Currently only one reduction dimension is | ||
| /// supported. Given the pattern: | ||
| /// %empty = tensor.empty() : tensor<i8> | ||
| /// %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) -> | ||
| /// tensor<i8> %reduced = linalg.reduce ins(%0 : tensor<147456xi8>) | ||
| /// outs(%filled : tensor<i8>) dimensions = [0] | ||
| /// (%in: i8, %init: i8) { | ||
| /// %3 = arith.addi %in, %init : i8 | ||
| /// linalg.yield %3 : i8 | ||
| /// } | ||
| /// The pattern is rewritten into: | ||
| /// %empty = tensor.empty() : tensor<i8> | ||
| /// %reduced = linalg.generic ins(%0 : tensor<147456xi8>) outs(%empty : | ||
| /// tensor<i8>) { | ||
| /// ^bb0(%in: i8, %init: i8): | ||
| /// %cst = arith.constant 0 : index | ||
| /// %index = linalg.index 0 : index | ||
| /// %cmp = arith.cmpi eq, %cst, %index : index | ||
|
Comment on lines
+1916
to
+1917
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The iteration order of reduction and parallel iterators in a linalg operation is undefined. I don't think you can assume that the first iteration of the iterator is 0.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be clear, I don't think you can do this fusion while adding linalg.index to the body, because that would mean you are assuming the first iteration index to be something. |
||
| /// %sum = arith.select %cmp, %c0, %init : i8 | ||
| /// %res = arith.addi %in, %sum : i8 | ||
| /// linalg.yield %res : i8 | ||
| /// } | ||
| void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns); | ||
|
|
||
| } // namespace linalg | ||
| } // namespace mlir | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| //===- FuseFillOpWithReduceOp.cpp - Fuse linalg fill with reduce producer -===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file implements a pattern that fuses a linalg.fill producer into a | ||
| // linalg.reduce consumer by rewriting the pair into a single linalg.generic | ||
| // whose body selects the fill value on the first reduction iteration. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
|
|
||
| #include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| using namespace mlir; | ||
|
|
||
| namespace { | ||
|
|
||
| /// Fold linalg.fill into linalg.reduce by creating a fused linalg.generic | ||
| /// operation. The fill operation is expected to happen only on the first index | ||
| /// of the reduction dimension. Currently only one reduction dimension is | ||
| /// supported. Given the pattern: | ||
| /// %empty = tensor.empty() : tensor<i8> | ||
| /// %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) -> | ||
| /// tensor<i8> %reduced = linalg.reduce ins(%0 : tensor<147456xi8>) | ||
| /// outs(%filled : tensor<i8>) dimensions = [0] | ||
| /// (%in: i8, %init: i8) { | ||
| /// %3 = arith.addi %in, %init : i8 | ||
| /// linalg.yield %3 : i8 | ||
| /// } | ||
| /// The pattern is rewritten into: | ||
| /// %empty = tensor.empty() : tensor<i8> | ||
| /// %reduced = linalg.generic ins(%0 : tensor<147456xi8>) outs(%empty : | ||
| /// tensor<i8>) { | ||
| /// ^bb0(%in: i8, %init: i8): | ||
| /// %cst = arith.constant 0 : index | ||
| /// %index = linalg.index %c0 : index | ||
| /// %cmp = arith.cmpi eq, %cst, %index : i1 | ||
| /// %sum = arith.select %cmp, %c0, %init : i8 | ||
| /// %res = arith.addi %in, %sum : i8 | ||
| /// linalg.yield %res : i8 | ||
| /// } | ||
| struct FoldFillWithReduceOp : public OpRewritePattern<linalg::ReduceOp> { | ||
| using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp, | ||
| PatternRewriter &rewriter) const override { | ||
| if (!reduceOp.hasPureTensorSemantics()) | ||
| return rewriter.notifyMatchFailure( | ||
| reduceOp, "skip reduce op with non-pure tensor semantics"); | ||
| if (reduceOp.getDimensions().size() != 1) | ||
| return rewriter.notifyMatchFailure( | ||
| reduceOp, "skip reduce op with non-single dimension"); | ||
| if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1) | ||
| return rewriter.notifyMatchFailure( | ||
| reduceOp, "skip reduce op with multiple number of inputs/results"); | ||
| auto fillOp = reduceOp.getInits()[0].getDefiningOp<linalg::FillOp>(); | ||
| if (!fillOp) | ||
| return rewriter.notifyMatchFailure( | ||
| reduceOp, | ||
| "skip reduce op with inits not directly based on fill operation"); | ||
|
|
||
| long dim = reduceOp.getDimensions()[0]; | ||
| // Note: on success, the `reduceOp` is replaced with a genericOp and no | ||
| // longer valid. | ||
| auto failureOrGenericOp = linalg::generalizeNamedOp(rewriter, reduceOp); | ||
| if (failed(failureOrGenericOp)) | ||
| return rewriter.notifyMatchFailure(reduceOp, | ||
| "failed to generalize reduce op"); | ||
|
|
||
| linalg::GenericOp genericReduceOp = *failureOrGenericOp; | ||
| auto operandIdx = -1; | ||
| for (auto &use : genericReduceOp->getOpOperands()) { | ||
| if (use.get().getDefiningOp() == fillOp) | ||
| operandIdx = use.getOperandNumber(); | ||
| } | ||
| assert(operandIdx != -1 && "fill op not found in reduce op uses"); | ||
|
|
||
| Location loc = genericReduceOp.getLoc(); | ||
| auto blockArg = genericReduceOp.getMatchingBlockArgument( | ||
| &genericReduceOp->getOpOperand(operandIdx)); | ||
| rewriter.setInsertionPointToStart(genericReduceOp.getBody()); | ||
| auto constZeroIndexOp = rewriter.create<arith::ConstantIndexOp>(loc, 0); | ||
| auto linalgIndexOp = rewriter.create<linalg::IndexOp>(loc, dim); | ||
| auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, | ||
| constZeroIndexOp.getResult(), | ||
| linalgIndexOp.getResult()); | ||
| auto selectOp = | ||
| rewriter.create<arith::SelectOp>(loc, cmpIOp, fillOp.value(), blockArg); | ||
| rewriter.replaceAllUsesExcept(blockArg, selectOp.getResult(), selectOp); | ||
| genericReduceOp->setOperand(operandIdx, fillOp.getDpsInitOperand(0)->get()); | ||
|
|
||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
/// Populate `patterns` with the pattern that folds a linalg.fill producer
/// into its linalg.reduce consumer by creating a fused linalg.generic.
void mlir::linalg::populateFuseFillOpWithReduceOpPatterns(
    RewritePatternSet &patterns) {
  patterns.add<FoldFillWithReduceOp>(patterns.getContext());
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
// RUN: mlir-opt -test-linalg-fuse-fill-op-with-reduce-op -split-input-file %s | FileCheck %s

// Case: additive reduction whose init comes from linalg.fill(0). The fill is
// expected to fold into a fused linalg.generic that selects the fill value on
// the first reduction iteration (linalg.index 0 == 0).
// CHECK-LABEL: func.func private @test_reduce_sum_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : i8
// CHECK: linalg.yield %[[VAL_10]] : i8
// CHECK: } -> tensor<i8>
// CHECK: return %[[VAL_11:.*]] : tensor<i8>
// CHECK: }

func.func private @test_reduce_sum_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  %c0_i8 = arith.constant 0 : i8
  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %3 = arith.addi %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

// -----

// Negative case: the reduce's init is not produced by a linalg.fill, so the
// pattern must not fire and the linalg.reduce must survive.
func.func private @test_missing_fill(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  // CHECK: linalg.reduce
  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%1 : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %3 = arith.addi %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

// -----

// Case: multiplicative reduction with fill(1) — exercises a non-additive
// combiner and a non-zero fill value flowing into the arith.select.
// CHECK-LABEL: func.func private @test_reduce_multiply_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i8
// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_9]] : i8
// CHECK: linalg.yield %[[VAL_10]] : i8
// CHECK: } -> tensor<i8>
// CHECK: return %[[VAL_11:.*]] : tensor<i8>
// CHECK: }

func.func private @test_reduce_multiply_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  %c1_i8 = arith.constant 1 : i8
  %2 = linalg.fill ins(%c1_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %3 = arith.muli %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

// -----

// Negative case: more than one reduction dimension — unsupported, so both the
// fill and the reduce must remain untouched.
func.func private @test_reduce_sum_on_multiple_dims(%arg0: tensor<2x147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  %c0_i8 = arith.constant 0 : i8
  // CHECK: linalg.fill
  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
  // CHECK: linalg.reduce
  %reduced = linalg.reduce ins(%arg0 : tensor<2x147456xi8>) outs(%2 : tensor<i8>) dimensions = [0, 1]
    (%in: i8, %init: i8) {
      %3 = arith.addi %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| //===- TestLinalgFuseFillOpWithReduceOp.cpp -----------------------------===// | ||
| // | ||
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
| // See https://llvm.org/LICENSE.txt for license information. | ||
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
| // | ||
| // This file implements a pass for testing fuse linalg fill with linalg reduce | ||
| // into a new linalg generic operation. | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include "mlir/Dialect/Func/IR/FuncOps.h" | ||
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
| #include "mlir/Pass/Pass.h" | ||
| #include "mlir/Pass/PassManager.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| using namespace mlir; | ||
|
|
||
| namespace { | ||
|
|
||
/// Test-only pass that applies the fill-with-reduce fusion patterns to each
/// function, exercised by the lit test via
/// -test-linalg-fuse-fill-op-with-reduce-op.
struct TestLinalgFuseFillOpWithReduceOp
    : public PassWrapper<TestLinalgFuseFillOpWithReduceOp,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseFillOpWithReduceOp)

  TestLinalgFuseFillOpWithReduceOp() = default;
  TestLinalgFuseFillOpWithReduceOp(
      const TestLinalgFuseFillOpWithReduceOp &pass) = default;
  // The rewrite creates arith/linalg/tensor ops, so those dialects must be
  // loaded before the pass runs.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect, linalg::LinalgDialect,
                    tensor::TensorDialect>();
  }
  StringRef getArgument() const final {
    return "test-linalg-fuse-fill-op-with-reduce-op";
  }
  StringRef getDescription() const final {
    return "Test fuse linalg fill with linalg reduce into a new linalg generic "
           "operation";
  }

  void runOnOperation() override {
    MLIRContext *context = &this->getContext();
    func::FuncOp funcOp = this->getOperation();

    RewritePatternSet patterns(context);
    linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
    // Apply greedily to a fixpoint; signal failure if the driver fails to
    // converge.
    if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

namespace mlir {
namespace test {
// Hook called from the test-pass registration TU to expose this pass.
void registerTestLinalgFuseFillOpWithReduceOp() {
  PassRegistration<TestLinalgFuseFillOpWithReduceOp>();
}
} // namespace test
} // namespace mlir
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general this is not the preferred way of fusing a fill with a reduction. The preferred way is to use a tile + fuse approach to fuse at a tile granularity (since fusing a fill with its consumer reduction operation results in an inherently imperfectly nested loop computation). The main issue here is that this adds a conditional to the innermost loop computation, which isn't generally performant.
But this still seems valid. Could you add some comments explaining the alternatives.