Skip to content

Commit ceda56b

Browse files
authored
[mlir][linalg] Morphism across linalg -- named, category and generic ops. (#148424)
Adds `linalg-morph-ops` pass to convert an op from one representation to another: named-op <--> category_op (elementwise, contraction, ..) <--> generic e.g. ```mlir %exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32> ``` After `mlir-opt -linalg-morph-ops=named-to-category ..` ```mlir %0 = linalg.elementwise kind=#linalg.elementwise_kind<exp> ins(%arg0 : tensor<16x8xf32> .. Note: this is generalization of `--linalg-generalize-named-ops` is the path `named-op --> generic-op` `--linalg-specialize-generic-ops` is the path `named-op <-- generic-op` email: [email protected]
1 parent f24c50a commit ceda56b

File tree

8 files changed

+290
-0
lines changed

8 files changed

+290
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,45 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
8989
];
9090
}
9191

92+
def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
93+
let summary = "Convert named op to category ops or generic and vice-versa";
94+
95+
let description = [{
96+
Convert a linalg op from one representation to another equivalent.
97+
For example, a linalg named op `linalg.add` can also be written as an
98+
category op `linalg.elementwise`, and can also be re-written as
99+
a `linalg.generic`, giving the morphism:
100+
101+
named-op <--> category_op (elementwise, contraction, ..) <--> generic
102+
103+
Note that the set of `linalg.generic` subsumes named and category ops
104+
and therefore not all `linalg.genric` can be converted to named or
105+
category op. Similarly, catgory ops subsume named ops.
106+
107+
Note:
108+
Legacy converters:
109+
`--linalg-generalize-named-ops` is the path `named-op --> generic-op`
110+
`--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
111+
}];
112+
let dependentDialects = ["linalg::LinalgDialect"];
113+
114+
let options = [
115+
// named-op <--> category <--> generic
116+
117+
// Lowering options
118+
Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false",
119+
"convert named ops to category op e.g. `linalg.elementwise`">,
120+
Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false",
121+
"convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
122+
Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
123+
"convert named ops e.g. `linalg.add` to `linalg.generic`">,
124+
125+
// Lifting options
126+
// TODOs: `generic-to-category`, `category-to-named`
127+
Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
128+
"convert linalg.generic to equivalent named ops"> ];
129+
}
130+
92131
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
93132
let summary = "Convert named ops into generic ops";
94133
let dependentDialects = ["linalg::LinalgDialect"];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,6 +1831,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
18311831
void populateLinalgGenericOpsSpecializationPatterns(
18321832
RewritePatternSet &patterns);
18331833

1834+
/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
1835+
/// to equivalent `linalg.elementwise`.
1836+
void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
1837+
18341838
/// Populates `patterns` with patterns that fold operations like
18351839
/// `linalg.transform` into elementwise op map.
18361840
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2323
InlineScalarOperands.cpp
2424
Interchange.cpp
2525
Loops.cpp
26+
MorphOps.cpp
2627
TransposeMatmul.cpp
2728
ShardingInterfaceImpl.cpp
2829
NamedOpConversions.cpp
30+
NamedToElementwise.cpp
2931
BlockPackMatmul.cpp
3032
PackAndUnpackPatterns.cpp
3133
Padding.cpp
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//===- MorphOps.cpp - conversion between named,category and generic ops ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements conversions between linalg ops:
10+
// named <--> category (elementwise, contraction, ..) <--> generic.
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/Complex/IR/Complex.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
16+
#include "mlir/Dialect/Linalg/Passes.h"
17+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18+
#include "mlir/Dialect/Math/IR/Math.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_LINALGMORPHOPSPASS
24+
#include "mlir/Dialect/Linalg/Passes.h.inc"
25+
} // namespace mlir
26+
27+
#define DEBUG_TYPE "linalg-morphism"
28+
29+
using namespace mlir;
30+
using namespace mlir::linalg;
31+
32+
namespace {
33+
struct LinalgMorphOpsPass
34+
: public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {
35+
36+
using impl::LinalgMorphOpsPassBase<
37+
LinalgMorphOpsPass>::LinalgMorphOpsPassBase;
38+
39+
void runOnOperation() override;
40+
};
41+
42+
void LinalgMorphOpsPass::runOnOperation() {
43+
44+
RewritePatternSet patterns(&getContext());
45+
46+
// Lowering paths (named -> category -> generic)
47+
if (namedToCategory) {
48+
populateLinalgNamedToElementwisePatterns(patterns);
49+
}
50+
if (namedToGeneric || categoryToGeneric) {
51+
populateLinalgNamedOpsGeneralizationPatterns(patterns);
52+
}
53+
54+
// Lifting paths (named <- category <- generic)
55+
if (genericToNamed) {
56+
populateLinalgGenericOpsSpecializationPatterns(patterns);
57+
}
58+
59+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
60+
signalPassFailure();
61+
}
62+
} // namespace
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements rewriting those linalg named ops that are essentially
10+
// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
11+
// optimization on `linalg.elementwise` such as folding transpose, broadcast.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
16+
#include "mlir/Dialect/Linalg/Passes.h"
17+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
#include "llvm/ADT/SmallVector.h"
21+
#include "llvm/ADT/TypeSwitch.h"
22+
23+
using namespace mlir;
24+
using namespace mlir::linalg;
25+
26+
#define DEBUG_TYPE "linalg-named-to-elementwise"
27+
28+
namespace {
29+
ElementwiseKind getKind(Operation *op) {
30+
return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
31+
.Case([](SelectOp) { return ElementwiseKind::select; })
32+
.Case([](AddOp) { return ElementwiseKind::add; })
33+
.Case([](SubOp) { return ElementwiseKind::sub; })
34+
.Case([](MulOp) { return ElementwiseKind::mul; })
35+
.Case([](DivOp) { return ElementwiseKind::div; })
36+
.Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
37+
.Case([](PowFOp) { return ElementwiseKind::powf; })
38+
.Case([](ExpOp) { return ElementwiseKind::exp; })
39+
.Case([](LogOp) { return ElementwiseKind::log; })
40+
.Case([](AbsOp) { return ElementwiseKind::abs; })
41+
.Case([](CeilOp) { return ElementwiseKind::ceil; })
42+
.Case([](FloorOp) { return ElementwiseKind::floor; })
43+
.Case([](NegFOp) { return ElementwiseKind::negf; })
44+
.Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
45+
.Case([](RoundOp) { return ElementwiseKind::round; })
46+
.Case([](SqrtOp) { return ElementwiseKind::sqrt; })
47+
.Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
48+
.Case([](SquareOp) { return ElementwiseKind::square; })
49+
.Case([](TanhOp) { return ElementwiseKind::tanh; })
50+
.Case([](ErfOp) { return ElementwiseKind::erf; })
51+
.Default([&](Operation *op) {
52+
llvm_unreachable("unhandled case in named to elementwise");
53+
return ElementwiseKind::sub;
54+
});
55+
}
56+
57+
template <typename NamedOpTy>
58+
struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
59+
using OpRewritePattern<NamedOpTy>::OpRewritePattern;
60+
61+
LogicalResult matchAndRewrite(NamedOpTy op,
62+
PatternRewriter &rewriter) const override {
63+
SmallVector<NamedAttribute> attrs;
64+
auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
65+
attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
66+
attrs.push_back(
67+
rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
68+
69+
rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
70+
op.getDpsInits(), attrs);
71+
return success();
72+
}
73+
};
74+
} // namespace
75+
76+
void mlir::linalg::populateLinalgNamedToElementwisePatterns(
77+
RewritePatternSet &patterns) {
78+
patterns.add<NamedToElementwisePattern<SelectOp>>(patterns.getContext());
79+
patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
80+
patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
81+
patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
82+
patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
83+
patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
84+
patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
85+
patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
86+
patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
87+
patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
88+
patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
89+
patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
90+
patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
91+
patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
92+
patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
93+
patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
94+
patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
95+
patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
96+
patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
97+
patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
98+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s
2+
3+
// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
4+
// CHECK: {{.*}} = linalg.elementwise
5+
// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
6+
// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
7+
// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
8+
//
9+
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
10+
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
11+
return %exp : tensor<16x8xf32>
12+
}
13+
14+
// ----
15+
16+
// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
17+
// CHECK: {{.*}} = linalg.elementwise
18+
// CHECK-SAME: kind=#linalg.elementwise_kind<add>
19+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
20+
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
21+
//
22+
func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
23+
%add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
24+
return %add : tensor<16x8xf32>
25+
}
26+
27+
// ----
28+
29+
// CHECK: @sub(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
30+
// CHECK: {{.*}} = linalg.elementwise
31+
// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
32+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
33+
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>)
34+
//
35+
func.func @sub(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
36+
%sub = linalg.sub ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
37+
return %sub : tensor<16x8xf32>
38+
}
39+
40+
// ----
41+
42+
// CHECK: @ternary_select(%[[A:.+]]: tensor<4x8x16xi1>, %[[B:.+]]: tensor<4x8x16xf32>, %[[C:.+]]: tensor<4x8x16xf32>)
43+
// CHECK: %[[E:.+]] = tensor.empty() : tensor<4x8x16xf32>
44+
// CHECK: {{.*}} = linalg.elementwise
45+
// CHECK-SAME: kind=#linalg.elementwise_kind<select>
46+
// CHECK-SAME: ins(%[[A]], %[[B]], %[[C]] : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
47+
// CHECK-SAME: outs(%[[E]] : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
48+
//
49+
func.func @ternary_select(%A: tensor<4x8x16xi1>, %B: tensor<4x8x16xf32>, %C: tensor<4x8x16xf32>)
50+
-> tensor<4x8x16xf32> {
51+
%empty = tensor.empty() : tensor<4x8x16xf32>
52+
%select = linalg.select
53+
ins(%A, %B, %C : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
54+
outs(%empty: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
55+
return %select : tensor<4x8x16xf32>
56+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Forward path `named -> category -> generic`
2+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | FileCheck %s --check-prefix=NAMED_TO_CATEGORY
3+
4+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | \
5+
// RUN: mlir-opt %s -linalg-morph-ops=category-to-generic | FileCheck %s --check-prefix=CATEGORY_TO_GENERIC
6+
7+
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
8+
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
9+
return %exp : tensor<16x8xf32>
10+
}
11+
// NAMED_TO_CATEGORY: linalg.elementwise
12+
// NAMED_TO_CATEGORY-NOT: linalg.exp
13+
14+
// CATEGORY_TO_GENERIC: linalg.generic
15+
// CATEGORY_TO_GENERIC-NOT: linalg.elementwise
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
2+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt %s -linalg-morph-ops=generic-to-named | \
3+
// RUN: FileCheck %s --check-prefix=ROUND_TRIP
4+
5+
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
6+
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
7+
return %exp : tensor<16x8xf32>
8+
}
9+
10+
// NAMED_TO_GENERIC: linalg.generic
11+
// NAMED_TO_GENERIC-NOT: linalg.exp
12+
13+
// ROUND_TRIP: linalg.exp
14+
// ROUND_TRIP-NOT: linalg.generic

0 commit comments

Comments
 (0)