Skip to content

Commit 56c8431

Browse files
committed
[mlir][linalg] Convert linalg.named to linalg.elementwise op.
Convert linalg.named ops which are elementwise (e.g. add/exp) to `linalg.elementwise`. Currently, named ops have to drop to linalg.generic (--generalize-named-ops), where one figures out which generic are elementwise. Also, folding of broadcast or transpose can occur then only at generic level. Instead, with this rewrite, these can happen now at linalg.elementwise.
1 parent 2c0d563 commit 56c8431

File tree

5 files changed

+166
-0
lines changed

5 files changed

+166
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
9999
let dependentDialects = ["linalg::LinalgDialect"];
100100
}
101101

102+
def LinalgNamedToElementwisePass : Pass<"linalg-named-to-elementwise"> {
103+
let summary = "Convert linalg named ops to elementwise where possible";
104+
let dependentDialects = ["linalg::LinalgDialect"];
105+
}
106+
102107
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
103108
let summary = "Fold transform, broadcast and other ops into elementwise";
104109
let dependentDialects = ["linalg::LinalgDialect"];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,6 +1810,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
18101810
void populateLinalgGenericOpsSpecializationPatterns(
18111811
RewritePatternSet &patterns);
18121812

1813+
/// Populates `patterns` that convert linalg named ops e.g. `linalg.add`
1814+
/// to equivalent `linalg.elementwise`.
1815+
void populateLinalgNamedToElementwisePatterns(RewritePatternSet &patterns);
1816+
18131817
/// Populates `patterns` with patterns that fold operations like
18141818
/// `linalg.transform` into elementwise op map.
18151819
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2626
TransposeMatmul.cpp
2727
MeshShardingInterfaceImpl.cpp
2828
NamedOpConversions.cpp
29+
NamedToElementwise.cpp
2930
BlockPackMatmul.cpp
3031
PackAndUnpackPatterns.cpp
3132
Padding.cpp
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
//===- NamedToElementwise.cpp - convert linalg named op into elementwise --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements rewriting those linalg named ops that are essentially
10+
// elementwise e.g. `linalg.exp`, to `linalg.elementwise`. This allows further
11+
// optimization on `linalg.elementwise` such as folding transpose, broadcast.
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
16+
#include "mlir/Dialect/Linalg/Passes.h"
17+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
#include "llvm/ADT/SmallVector.h"
21+
#include "llvm/ADT/TypeSwitch.h"
22+
23+
namespace mlir {
24+
#define GEN_PASS_DEF_LINALGNAMEDTOELEMENTWISEPASS
25+
#include "mlir/Dialect/Linalg/Passes.h.inc"
26+
} // namespace mlir
27+
28+
using namespace mlir;
29+
using namespace mlir::linalg;
30+
31+
#define DEBUG_TYPE "linalg-named-to-elementwise"
32+
33+
namespace {
34+
ElementwiseKind getKind(Operation *op) {
35+
return llvm::TypeSwitch<Operation *, ElementwiseKind>(op)
36+
.Case([](SelectOp) { return ElementwiseKind::select; })
37+
.Case([](AddOp) { return ElementwiseKind::add; })
38+
.Case([](SubOp) { return ElementwiseKind::sub; })
39+
.Case([](MulOp) { return ElementwiseKind::mul; })
40+
.Case([](DivOp) { return ElementwiseKind::div; })
41+
.Case([](DivUnsignedOp) { return ElementwiseKind::div_unsigned; })
42+
.Case([](PowFOp) { return ElementwiseKind::powf; })
43+
.Case([](ExpOp) { return ElementwiseKind::exp; })
44+
.Case([](LogOp) { return ElementwiseKind::log; })
45+
.Case([](AbsOp) { return ElementwiseKind::abs; })
46+
.Case([](CeilOp) { return ElementwiseKind::ceil; })
47+
.Case([](FloorOp) { return ElementwiseKind::floor; })
48+
.Case([](NegFOp) { return ElementwiseKind::negf; })
49+
.Case([](ReciprocalOp) { return ElementwiseKind::reciprocal; })
50+
.Case([](RoundOp) { return ElementwiseKind::round; })
51+
.Case([](SqrtOp) { return ElementwiseKind::sqrt; })
52+
.Case([](RsqrtOp) { return ElementwiseKind::rsqrt; })
53+
.Case([](SquareOp) { return ElementwiseKind::square; })
54+
.Case([](TanhOp) { return ElementwiseKind::tanh; })
55+
.Case([](ErfOp) { return ElementwiseKind::erf; })
56+
.Default([&](Operation *op) {
57+
assert(false && "unexpected op");
58+
return ElementwiseKind::sub;
59+
});
60+
}
61+
62+
template <typename NamedOpTy>
63+
struct NamedToElementwisePattern : public OpRewritePattern<NamedOpTy> {
64+
using OpRewritePattern<NamedOpTy>::OpRewritePattern;
65+
66+
LogicalResult matchAndRewrite(NamedOpTy op,
67+
PatternRewriter &rewriter) const override {
68+
SmallVector<NamedAttribute> attrs;
69+
auto kindAttr = ElementwiseKindAttr::get(op.getContext(), getKind(op));
70+
attrs.push_back(rewriter.getNamedAttr("kind", kindAttr));
71+
attrs.push_back(
72+
rewriter.getNamedAttr("indexing_maps", op.getIndexingMaps()));
73+
74+
rewriter.replaceOpWithNewOp<ElementwiseOp>(op, op.getDpsInputs(),
75+
op.getDpsInits(), attrs);
76+
return success();
77+
}
78+
};
79+
80+
struct LinalgNamedToElementwisePass
81+
: public impl::LinalgNamedToElementwisePassBase<
82+
LinalgNamedToElementwisePass> {
83+
using impl::LinalgNamedToElementwisePassBase<
84+
LinalgNamedToElementwisePass>::LinalgNamedToElementwisePassBase;
85+
86+
void runOnOperation() override {
87+
Operation *op = getOperation();
88+
RewritePatternSet patterns(op->getContext());
89+
populateLinalgNamedToElementwisePatterns(patterns);
90+
91+
if (failed(applyPatternsGreedily(op, std::move(patterns))))
92+
return signalPassFailure();
93+
}
94+
};
95+
} // namespace
96+
97+
void mlir::linalg::populateLinalgNamedToElementwisePatterns(
98+
RewritePatternSet &patterns) {
99+
patterns.add<NamedToElementwisePattern<AddOp>>(patterns.getContext());
100+
patterns.add<NamedToElementwisePattern<SubOp>>(patterns.getContext());
101+
patterns.add<NamedToElementwisePattern<MulOp>>(patterns.getContext());
102+
patterns.add<NamedToElementwisePattern<DivOp>>(patterns.getContext());
103+
patterns.add<NamedToElementwisePattern<DivUnsignedOp>>(patterns.getContext());
104+
patterns.add<NamedToElementwisePattern<PowFOp>>(patterns.getContext());
105+
patterns.add<NamedToElementwisePattern<ExpOp>>(patterns.getContext());
106+
patterns.add<NamedToElementwisePattern<LogOp>>(patterns.getContext());
107+
patterns.add<NamedToElementwisePattern<AbsOp>>(patterns.getContext());
108+
patterns.add<NamedToElementwisePattern<CeilOp>>(patterns.getContext());
109+
patterns.add<NamedToElementwisePattern<FloorOp>>(patterns.getContext());
110+
patterns.add<NamedToElementwisePattern<NegFOp>>(patterns.getContext());
111+
patterns.add<NamedToElementwisePattern<ReciprocalOp>>(patterns.getContext());
112+
patterns.add<NamedToElementwisePattern<RoundOp>>(patterns.getContext());
113+
patterns.add<NamedToElementwisePattern<SqrtOp>>(patterns.getContext());
114+
patterns.add<NamedToElementwisePattern<RsqrtOp>>(patterns.getContext());
115+
patterns.add<NamedToElementwisePattern<SquareOp>>(patterns.getContext());
116+
patterns.add<NamedToElementwisePattern<TanhOp>>(patterns.getContext());
117+
patterns.add<NamedToElementwisePattern<ErfOp>>(patterns.getContext());
118+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt %s -linalg-named-to-elementwise -split-input-file | FileCheck %s
2+
3+
// CHECK: @exp(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
4+
// CHECK: {{.*}} = linalg.elementwise
5+
// CHECK-SAME: kind=#linalg.elementwise_kind<exp>
6+
// CHECK-SAME: ins(%[[A]] : tensor<16x8xf32>)
7+
// CHECK-SAME: outs(%[[B]] : tensor<16x8xf32>) -> tensor<16x8xf32>
8+
//
9+
func.func @exp(%A : tensor<16x8xf32>, %B : tensor<16x8xf32>) -> tensor<16x8xf32> {
10+
%exp = linalg.exp ins(%A : tensor<16x8xf32>) outs(%B : tensor<16x8xf32>) -> tensor<16x8xf32>
11+
return %exp : tensor<16x8xf32>
12+
}
13+
14+
// ----
15+
16+
// CHECK: @add(%[[A:.+]]: tensor<16x8xf32>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
17+
// CHECK: {{.*}} = linalg.elementwise
18+
// CHECK-SAME: kind=#linalg.elementwise_kind<add>
19+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf32>, tensor<16x8xf32>)
20+
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
21+
//
22+
func.func @add(%A : tensor<16x8xf32>, %B: tensor<16x8xf32>, %C : tensor<16x8xf32>) -> tensor<16x8xf32> {
23+
%add = linalg.add ins(%A, %B : tensor<16x8xf32>, tensor<16x8xf32>) outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
24+
return %add : tensor<16x8xf32>
25+
}
26+
27+
// ----
28+
29+
// CHECK: @sub(%[[A:.+]]: memref<16x8xf32>, %[[B:.+]]: memref<16x8xf32>, %[[C:.+]]: memref<16x8xf32>) {
30+
// CHECK: linalg.elementwise
31+
// CHECK-SAME: kind=#linalg.elementwise_kind<sub>
32+
// CHECK-SAME: ins(%[[A]], %[[B]] : memref<16x8xf32>, memref<16x8xf32>)
33+
// CHECK-SAME: outs(%[[C]] : memref<16x8xf32>)
34+
//
35+
func.func @sub(%A : memref<16x8xf32>, %B: memref<16x8xf32>, %C : memref<16x8xf32>) {
36+
linalg.sub ins(%A, %B : memref<16x8xf32>, memref<16x8xf32>) outs(%C : memref<16x8xf32>)
37+
return
38+
}

0 commit comments

Comments
 (0)