Skip to content

Commit 32eccb5

Browse files
committed
Add linalg-morph pass
1 parent 86e19e1 commit 32eccb5

File tree

6 files changed

+154
-27
lines changed

6 files changed

+154
-27
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,51 @@ def LinalgInlineScalarOperandsPass : Pass<"linalg-inline-scalar-operands"> {
8989
];
9090
}
9191

92+
def LinalgMorphOpsPass : Pass<"linalg-morph-ops"> {
93+
let summary = "Convert named op to category ops or generic and vice-versa";
94+
95+
let description = [{
96+
Convert a linalg op from one representation to another equivalent.
97+
98+
For example, a linalg named op `linalg.add` can also be written as an
99+
category op `linalg.elementwise`, and can also be re-written as
100+
a `linalg.generic`, giving the morphism:
101+
102+
named-op <--> category_op (elementwise, contraction, ..) <--> generic
103+
104+
Generic is a bigger set than named and category ops and so not all generics
105+
can be converted to single category-op or named-op. Similarly, category
106+
ops are bigger set than named ops.
107+
108+
Note:
109+
Legacy converters (will be deprecated):
110+
`--linalg-generalize-named-ops` is the path `named-op --> generic-op`
111+
`--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
112+
}];
113+
let dependentDialects = ["linalg::LinalgDialect"];
114+
115+
let options = [
116+
// named-op <--> category <--> generic
117+
Option<"namedToCategory", "named-to-category", "bool", /*default=*/"false",
118+
"convert named ops to category op e.g. `linalg.elementwise`">,
119+
120+
Option<"categoryToGeneric", "category-to-generic", "bool", /*default=*/"false",
121+
"convert category ops e.g. `linalg.elementwise` to `linalg.generic`">,
122+
123+
Option<"namedToGeneric", "named-to-generic", "bool", /*default=*/"false",
124+
"convert named ops e.g. `linalg.add` to `linalg.generic`">,
125+
126+
Option<"genericToCategory", "generic-to-category", "bool", /*default=*/"false",
127+
"convert generic ops to category op e.g. `linalg.contraction`">,
128+
129+
Option<"categoryToNamed", "category-to-named", "bool", /*default=*/"false",
130+
"convert category ops to equivalent named ops">,
131+
132+
Option<"genericToNamed", "generic-to-named", "bool", /*default=*/"false",
133+
"convert linalg.generic to equivalent named ops">
134+
];
135+
}
136+
92137
def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
93138
let summary = "Convert named ops into generic ops";
94139
let dependentDialects = ["linalg::LinalgDialect"];
@@ -99,11 +144,6 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
99144
let dependentDialects = ["linalg::LinalgDialect"];
100145
}
101146

102-
def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
103-
let summary = "Convert linalg named ops to elementwise where possible";
104-
let dependentDialects = ["linalg::LinalgDialect"];
105-
}
106-
107147
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
108148
let summary = "Fold transform, broadcast and other ops into elementwise";
109149
let dependentDialects = ["linalg::LinalgDialect"];

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2323
InlineScalarOperands.cpp
2424
Interchange.cpp
2525
Loops.cpp
26+
MorphOps.cpp
2627
TransposeMatmul.cpp
2728
ShardingInterfaceImpl.cpp
2829
NamedOpConversions.cpp
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- MorphOps.cpp - conversion between named,category and generic ops ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements conversions between:
10+
// named <--> category (elementwise, contraction, ..) <--> generic ops.
11+
//
12+
// For example, a named op such `linalg.add` can also be re-written as an
13+
// equivalent category op `linalg.elementwise` and also as a `linalg.generic`.
14+
//
15+
// Generic is a bigger set than named ops and so not all generics can be
16+
// converted to single category-op or named-op. Similarly, category-ops
17+
// are bigger in representational possiblities than named ops e.g.
18+
// `linalg.add` has no affine maps attached, but `linalg.elementwise` does.
19+
//
20+
// Note:
21+
// Legacy converters (will be deprecated):
22+
// `--linalg-generalize-named-ops` is the path `named-op --> generic-op`
23+
// `--linalg-specialize-generic-ops` is the path `named-op <-- generic-op`
24+
//===----------------------------------------------------------------------===//
25+
26+
#include "mlir/Dialect/Complex/IR/Complex.h"
27+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
28+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
29+
#include "mlir/Dialect/Linalg/Passes.h"
30+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
31+
#include "mlir/Dialect/Math/IR/Math.h"
32+
#include "mlir/IR/PatternMatch.h"
33+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
34+
35+
namespace mlir {
36+
#define GEN_PASS_DEF_LINALGMORPHOPSPASS
37+
#include "mlir/Dialect/Linalg/Passes.h.inc"
38+
} // namespace mlir
39+
40+
#define DEBUG_TYPE "linalg-morphism"
41+
42+
using namespace mlir;
43+
using namespace mlir::linalg;
44+
45+
namespace {
46+
struct LinalgMorphOpsPass
47+
: public impl::LinalgMorphOpsPassBase<LinalgMorphOpsPass> {
48+
49+
using impl::LinalgMorphOpsPassBase<
50+
LinalgMorphOpsPass>::LinalgMorphOpsPassBase;
51+
52+
void runOnOperation() override;
53+
};
54+
55+
void LinalgMorphOpsPass::runOnOperation() {
56+
57+
RewritePatternSet patterns(&getContext());
58+
59+
// Lowering paths (named -> category -> generic)
60+
if (namedToCategory) {
61+
// TODO: named -> contraction-op
62+
populateLinalgNamedToElementwisePatterns(patterns);
63+
}
64+
if (namedToGeneric || categoryToGeneric) {
65+
populateLinalgNamedOpsGeneralizationPatterns(patterns);
66+
}
67+
68+
// Lifting paths (named <- category <- generic)
69+
if (genericToCategory) {
70+
// TODO.
71+
}
72+
if (categoryToNamed) {
73+
// TODO: if there is a case for this.
74+
}
75+
if (genericToNamed) {
76+
populateLinalgGenericOpsSpecializationPatterns(patterns);
77+
}
78+
79+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
80+
signalPassFailure();
81+
}
82+
} // namespace

mlir/lib/Dialect/Linalg/Transforms/NamedToElementwise.cpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@
2020
#include "llvm/ADT/SmallVector.h"
2121
#include "llvm/ADT/TypeSwitch.h"
2222

23-
namespace mlir {
24-
#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
25-
#include "mlir/Dialect/Linalg/Passes.h.inc"
26-
} // namespace mlir
27-
2823
using namespace mlir;
2924
using namespace mlir::linalg;
3025

@@ -76,22 +71,6 @@ struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
7671
return success();
7772
}
7873
};
79-
80-
struct LinalgNamedToElementwisePass
81-
: public impl::LinalgNamedToElementwisePassBase<
82-
LinalgNamedToElementwisePass> {
83-
using impl::LinalgNamedToElementwisePassBase<
84-
LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
85-
86-
void runOnOperation() override {
87-
Operation *op = getOperation();
88-
RewritePatternSet patterns(op->getContext());
89-
populateLinalgNamedToElementwisePatterns(patterns);
90-
91-
if (failed(applyPatternsGreedily(op, std::move(patterns))))
92-
return signalPassFailure();
93-
}
94-
};
9574
} // namespace
9675

9776
void mlir::linalg::populateLinalgNamedToElementwisePatterns(

mlir/test/Dialect/Linalg/elementwise/named_to_elementwise.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category -split-input-file | FileCheck %s
22

33
// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
44
// CHECK: {{.*}} = linalg.elementwise
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Forward path `named -> category -> generic`
2+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | FileCheck %s --check-prefix=NAMED_TO_CATEGORY
3+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | FileCheck %s --check-prefix=NAMED_TO_GENERIC
4+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-category | \
5+
// RUN: mlir-opt %s -linalg-morph-ops=category-to-generic | FileCheck %s --check-prefix=CATEGORY_TO_GENERIC
6+
//
7+
// Backward path `named <- category <- generic`
8+
// RUN: mlir-opt %s -linalg-morph-ops=named-to-generic | mlir-opt %s -linalg-morph-ops=generic-to-named | \
9+
// RUN: FileCheck %s --check-prefix=GENERIC_TO_NAMED
10+
11+
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
12+
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
13+
return %exp : tensor<16x8xf32>
14+
}
15+
// NAMED_TO_CATEGORY: linalg.elementwise
16+
// NAMED_TO_CATEGORY-NOT: linalg.exp
17+
18+
// NAMED_TO_GENERIC: linalg.generic
19+
// NAMED_TO_GENERIC-NOT: linalg.exp
20+
21+
// CATEGORY_TO_GENERIC: linalg.generic
22+
// CATEGORY_TO_GENERIC-NOT: linalg.elementwise
23+
24+
// GENERIC_TO_NAMED: linalg.exp
25+
// GENERIC_TO_NAMED-NOT: linalg.generic

0 commit comments

Comments
 (0)