Skip to content

Commit 56afd0f

Browse files
author
Aviad Cohen
committed
[mlir][Linalg]: Add rewrite pattern to fuse fill with reduce operations
1 parent 642e84f commit 56afd0f

File tree

7 files changed

+290
-0
lines changed

7 files changed

+290
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
18931893
/// convert to a `linalg.dot`.
18941894
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
18951895

1896+
/// Add patterns to fuse a linalg fill operation with a linalg operation.
1897+
/// Add patterns to fold linalg.fill into linalg.reduce by creating a fused
1898+
/// linalg.generic operation.
1899+
/// The fill operation is expected to happen only on the first index
1900+
/// of the reduction dimension. Currently only one reduction dimension is
1901+
/// supported. Given the pattern:
1902+
/// %empty = tensor.empty() : tensor<i8>
1903+
/// %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) ->
1904+
/// tensor<i8> %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
1905+
/// outs(%filled : tensor<i8>) dimensions = [0]
1906+
/// (%in: i8, %init: i8) {
1907+
/// %3 = arith.addi %in, %init : i8
1908+
/// linalg.yield %3 : i8
1909+
/// }
1910+
/// The pattern is rewritten into:
1911+
/// %empty = tensor.empty() : tensor<i8>
1912+
/// %reduced = linalg.generic ins(%0 : tensor<147456xi8>) outs(%empty :
1913+
/// tensor<i8>) {
1914+
/// ^bb0(%in: i8, %init: i8):
1915+
/// %cst = arith.constant 0 : index
1916+
/// %index = linalg.index %c0 : index
1917+
/// %cmp = arith.cmpi eq, %cst, %index : i1
1918+
/// %sum = arith.select %cmp, %c0, %init : i8
1919+
/// %res = arith.addi %in, %sum : i8
1920+
/// linalg.yield %res : i8
1921+
/// }
1922+
void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
1923+
18961924
} // namespace linalg
18971925
} // namespace mlir
18981926

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
1515
EraseUnusedOperandsAndResults.cpp
1616
FoldAddIntoDest.cpp
1717
FuseFillOpWithReduceOp.cpp
FusePadOpWithLinalgProducer.cpp
Fusion.cpp
1920
Generalization.cpp
2021
Hoisting.cpp
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//===- FuseFillOpWithReduceOp.cpp - Fuse linalg fill with reduce producer -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pattern that folds a linalg.fill initializing the
// accumulator of a linalg.reduce into a single fused linalg.generic
// operation.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16+
17+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
using namespace mlir;
21+
22+
namespace {
23+
24+
/// Fold linalg.fill into linalg.reduce by creating a fused linalg.generic
25+
/// operation. The fill operation is expected to happen only on the first index
26+
/// of the reduction dimension. Currently only one reduction dimension is
27+
/// supported. Given the pattern:
28+
/// %empty = tensor.empty() : tensor<i8>
29+
/// %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) ->
30+
/// tensor<i8> %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
31+
/// outs(%filled : tensor<i8>) dimensions = [0]
32+
/// (%in: i8, %init: i8) {
33+
/// %3 = arith.addi %in, %init : i8
34+
/// linalg.yield %3 : i8
35+
/// }
36+
/// The pattern is rewritten into:
37+
/// %empty = tensor.empty() : tensor<i8>
38+
/// %reduced = linalg.generic ins(%0 : tensor<147456xi8>) outs(%empty :
39+
/// tensor<i8>) {
40+
/// ^bb0(%in: i8, %init: i8):
41+
/// %cst = arith.constant 0 : index
42+
/// %index = linalg.index %c0 : index
43+
/// %cmp = arith.cmpi eq, %cst, %index : i1
44+
/// %sum = arith.select %cmp, %c0, %init : i8
45+
/// %res = arith.addi %in, %sum : i8
46+
/// linalg.yield %res : i8
47+
/// }
48+
struct FoldFillWithReduceOp : public OpRewritePattern<linalg::ReduceOp> {
49+
using OpRewritePattern<linalg::ReduceOp>::OpRewritePattern;
50+
51+
LogicalResult matchAndRewrite(linalg::ReduceOp reduceOp,
52+
PatternRewriter &rewriter) const override {
53+
if (!reduceOp.hasPureTensorSemantics())
54+
return rewriter.notifyMatchFailure(
55+
reduceOp, "skip reduce op with non-pure tensor semantics");
56+
if (reduceOp.getDimensions().size() != 1)
57+
return rewriter.notifyMatchFailure(
58+
reduceOp, "skip reduce op with non-single dimension");
59+
if (reduceOp.getNumDpsInputs() != 1 || reduceOp.getNumDpsInits() != 1)
60+
return rewriter.notifyMatchFailure(
61+
reduceOp, "skip reduce op with multiple number of inputs/results");
62+
auto fillOp = reduceOp.getInits()[0].getDefiningOp<linalg::FillOp>();
63+
if (!fillOp)
64+
return rewriter.notifyMatchFailure(
65+
reduceOp,
66+
"skip reduce op with inits not directly based on fill operation");
67+
68+
long dim = reduceOp.getDimensions()[0];
69+
// Note: on success, the `reduceOp` is replaced with a genericOp and no
70+
// longer valid.
71+
auto failureOrGenericOp = linalg::generalizeNamedOp(rewriter, reduceOp);
72+
if (failed(failureOrGenericOp))
73+
return rewriter.notifyMatchFailure(reduceOp,
74+
"failed to generalize reduce op");
75+
76+
linalg::GenericOp genericReduceOp = *failureOrGenericOp;
77+
auto operandIdx = -1;
78+
for (auto &use : genericReduceOp->getOpOperands()) {
79+
if (use.get().getDefiningOp() == fillOp)
80+
operandIdx = use.getOperandNumber();
81+
}
82+
assert(operandIdx != -1 && "fill op not found in reduce op uses");
83+
84+
Location loc = genericReduceOp.getLoc();
85+
auto blockArg = genericReduceOp.getMatchingBlockArgument(
86+
&genericReduceOp->getOpOperand(operandIdx));
87+
rewriter.setInsertionPointToStart(genericReduceOp.getBody());
88+
auto constZeroIndexOp = rewriter.create<arith::ConstantIndexOp>(loc, 0);
89+
auto linalgIndexOp = rewriter.create<linalg::IndexOp>(loc, dim);
90+
auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
91+
constZeroIndexOp.getResult(),
92+
linalgIndexOp.getResult());
93+
auto selectOp =
94+
rewriter.create<arith::SelectOp>(loc, cmpIOp, fillOp.value(), blockArg);
95+
rewriter.replaceAllUsesExcept(blockArg, selectOp.getResult(), selectOp);
96+
genericReduceOp->setOperand(operandIdx, fillOp.getDpsInitOperand(0)->get());
97+
98+
return success();
99+
}
100+
};
101+
102+
} // namespace
103+
104+
void mlir::linalg::populateFuseFillOpWithReduceOpPatterns(
    RewritePatternSet &patterns) {
  // Register the fill-into-reduce fusion pattern on the given set.
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldFillWithReduceOp>(context);
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// RUN: mlir-opt -test-linalg-fuse-fill-op-with-reduce-op -split-input-file %s | FileCheck %s

// CHECK-LABEL: func.func private @test_reduce_sum_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
// CHECK-NOT: linalg.fill
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_5]], %[[VAL_9]] : i8
// CHECK: linalg.yield %[[VAL_10]] : i8
// CHECK: } -> tensor<i8>
// CHECK: return %[[VAL_4]] : tensor<i8>
// CHECK: }

func.func private @test_reduce_sum_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  %c0_i8 = arith.constant 0 : i8
  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %3 = arith.addi %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

// -----

// Negative test: the reduce init does not come from a linalg.fill, so the
// reduce must be left untouched.
func.func private @test_missing_fill(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  // CHECK: linalg.reduce
  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%1 : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %3 = arith.addi %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

// -----

// CHECK-LABEL: func.func private @test_reduce_multiply_kernel(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<147456xi8>) -> tensor<i8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i8
// CHECK: %[[VAL_3:.*]] = tensor.empty() : tensor<i8>
// CHECK-NOT: linalg.fill
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction"]} ins(%[[VAL_0]] : tensor<147456xi8>) outs(%[[VAL_3]] : tensor<i8>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i8, %[[VAL_6:.*]]: i8):
// CHECK: %[[VAL_7:.*]] = linalg.index 0 : index
// CHECK: %[[VAL_8:.*]] = arith.cmpi eq, %[[VAL_7]], %[[VAL_1]] : index
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_8]], %[[VAL_2]], %[[VAL_6]] : i8
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_5]], %[[VAL_9]] : i8
// CHECK: linalg.yield %[[VAL_10]] : i8
// CHECK: } -> tensor<i8>
// CHECK: return %[[VAL_4]] : tensor<i8>
// CHECK: }

func.func private @test_reduce_multiply_kernel(%arg0: tensor<147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  %c1_i8 = arith.constant 1 : i8
  %2 = linalg.fill ins(%c1_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
  %reduced = linalg.reduce ins(%arg0 : tensor<147456xi8>) outs(%2 : tensor<i8>) dimensions = [0]
    (%in: i8, %init: i8) {
      %3 = arith.muli %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

// -----

// Negative test: multiple reduction dimensions are not supported, so both the
// fill and the reduce must survive.
func.func private @test_reduce_sum_on_multiple_dims(%arg0: tensor<2x147456xi8>) -> (tensor<i8>) {
  %1 = tensor.empty() : tensor<i8>
  %c0_i8 = arith.constant 0 : i8
  // CHECK: linalg.fill
  %2 = linalg.fill ins(%c0_i8 : i8) outs(%1 : tensor<i8>) -> tensor<i8>
  // CHECK: linalg.reduce
  %reduced = linalg.reduce ins(%arg0 : tensor<2x147456xi8>) outs(%2 : tensor<i8>) dimensions = [0, 1]
    (%in: i8, %init: i8) {
      %3 = arith.addi %in, %init : i8
      linalg.yield %3 : i8
    }
  return %reduced : tensor<i8>
}

mlir/test/lib/Dialect/Linalg/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
66
TestLinalgElementwiseFusion.cpp
TestLinalgFuseFillOpWithReduceOp.cpp
TestLinalgFusionTransforms.cpp
TestLinalgRankReduceContractionOps.cpp
910
TestLinalgTransforms.cpp
1011
TestPadFusion.cpp
1112

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//===- TestLinalgFuseFillOpWithReduceOp.cpp -----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a test pass that fuses a linalg.fill with a
// linalg.reduce into a new fused linalg.generic operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Func/IR/FuncOps.h"
15+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
16+
#include "mlir/Pass/Pass.h"
17+
#include "mlir/Pass/PassManager.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
using namespace mlir;
21+
22+
namespace {
23+
24+
struct TestLinalgFuseFillOpWithReduceOp
25+
: public PassWrapper<TestLinalgFuseFillOpWithReduceOp,
26+
OperationPass<func::FuncOp>> {
27+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseFillOpWithReduceOp)
28+
29+
TestLinalgFuseFillOpWithReduceOp() = default;
30+
TestLinalgFuseFillOpWithReduceOp(
31+
const TestLinalgFuseFillOpWithReduceOp &pass) = default;
32+
void getDependentDialects(DialectRegistry &registry) const override {
33+
registry.insert<arith::ArithDialect, linalg::LinalgDialect,
34+
tensor::TensorDialect>();
35+
}
36+
StringRef getArgument() const final {
37+
return "test-linalg-fuse-fill-op-with-reduce-op";
38+
}
39+
StringRef getDescription() const final {
40+
return "Test fuse linalg fill with linalg reduce into a new linalg generic "
41+
"operation";
42+
}
43+
44+
void runOnOperation() override {
45+
MLIRContext *context = &this->getContext();
46+
func::FuncOp funcOp = this->getOperation();
47+
48+
RewritePatternSet patterns(context);
49+
linalg::populateFuseFillOpWithReduceOpPatterns(patterns);
50+
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
51+
return signalPassFailure();
52+
}
53+
};
54+
55+
} // namespace
56+
57+
namespace mlir {
58+
namespace test {
59+
void registerTestLinalgFuseFillOpWithReduceOp() {
60+
PassRegistration<TestLinalgFuseFillOpWithReduceOp>();
61+
}
62+
} // namespace test
63+
} // namespace mlir

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
111111
void registerTestLinalgElementwiseFusion();
// Keep this list alphabetically sorted.
void registerTestLinalgFuseFillOpWithReduceOp();
void registerTestLinalgGreedyFusion();
void registerTestLinalgRankReduceContractionOps();
114115
void registerTestLinalgTransforms();
115116
void registerTestLivenessAnalysisPass();
116117
void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
251252
mlir::test::registerTestLinalgElementwiseFusion();
// Keep this list alphabetically sorted.
mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgRankReduceContractionOps();
254256
mlir::test::registerTestLinalgTransforms();
255257
mlir::test::registerTestLivenessAnalysisPass();
256258
mlir::test::registerTestLivenessPass();

0 commit comments

Comments
 (0)