Skip to content

Commit c8ec0f4

Browse files
committed
[mlir][linalg][elementwise] Fold transpose into new elementwise
1 parent 23aca2f commit c8ec0f4

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,12 +601,24 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
601601
[{
602602
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
603603
attributes, ElementwiseOp::getRegionBuilder());
604-
}]>
604+
}]>,
605+
606+
// Builder overload that takes the elementwise kind and the indexing maps as
// explicit attributes, in addition to the inputs/outputs and an optional list
// of extra named attributes (defaults to empty).
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
607+
"ElementwiseKindAttr":$kind,
608+
"ArrayAttr":$indexingMaps,
609+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
610+
[{
611+
// Record `kind` and `indexing_maps` on the operation state before the
// structured-op build helper is invoked.
$_state.addAttribute("kind", kind);
612+
$_state.addAttribute("indexing_maps", indexingMaps);
613+
// No explicit result types are passed (std::nullopt).
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
614+
attributes, ElementwiseOp::getRegionBuilder());
615+
}]>
605616
];
606617

607618
let hasCustomAssemblyFormat = 1;
608619
let hasFolder = 1;
609620
let hasVerifier = 1;
621+
let hasCanonicalizer = 1;
610622

611623
let extraClassDeclaration = structuredOpsBaseDecls # [{
612624
/// Get the arity enum corresponding to the kind of op, e.g. if arg is

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Arith/IR/Arith.h"
1818
#include "mlir/Dialect/Arith/Utils/Utils.h"
1919
#include "mlir/Dialect/Complex/IR/Complex.h"
20+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
2021
#include "mlir/Dialect/Math/IR/Math.h"
2122
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2223
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4285,6 +4286,47 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
42854286
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
42864287
}
42874288

4289+
namespace {
4290+
/// Rewrite pattern on `ElementwiseOp` that folds producer `TransposeOp`s into
/// the consumer's indexing maps.
///
/// An input is folded only when its current indexing map is the identity and
/// its value is defined by a `TransposeOp`. In that case the transpose's own
/// input replaces it, paired with the indexing map the transpose uses for that
/// input. Every other input is kept together with its existing map.
struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
4291+
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
4292+
4293+
/// Returns failure() when no input could be folded; otherwise replaces `op`
/// with a new `ElementwiseOp` built from the rewritten inputs and maps and
/// returns success().
LogicalResult matchAndRewrite(ElementwiseOp op,
4294+
PatternRewriter &rewriter) const override {
4295+
// Becomes true once at least one input has been rewritten.
bool changed = false;
4296+
// Inputs and indexing maps for the replacement op, filled in operand order.
SmallVector<Value> newIns;
4297+
SmallVector<AffineMap> newMaps;
4298+
for (OpOperand *operand : op.getDpsInputOperands()) {
4299+
AffineMap map = op.getMatchingIndexingMap(operand);
4300+
// Null when the operand is not produced by a TransposeOp.
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
4301+
4302+
// Inputs that already carry a non-identity map are left alone.
if (!map.isIdentity() || !transposeOp) {
4303+
// push in original operand and its map.
4304+
newIns.push_back(operand->get());
4305+
newMaps.push_back(map);
4306+
continue;
4307+
}
4308+
// Read directly from the value that fed the transpose.
newIns.push_back(transposeOp.getInput());
4309+
// push in transposeOp's inverse permutation map.
4310+
newMaps.push_back(transposeOp.getMatchingIndexingMap(
4311+
transposeOp.getDpsInputOperand(0)));
4312+
changed = true;
4313+
}
4314+
if (!changed)
4315+
return failure();
4316+
// The last entry of the original indexing maps is carried over unchanged
// (presumably the map of the single init/output operand; confirm).
newMaps.push_back(op.getIndexingMapsArray().back());
4317+
4318+
// Only the first init is forwarded; the kind attribute is reused as is.
rewriter.replaceOpWithNewOp<ElementwiseOp>(
4319+
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
4320+
rewriter.getAffineMapArrayAttr(newMaps));
4321+
return success();
4322+
}
4323+
};
4324+
} // namespace
4325+
/// Adds the canonicalization patterns of `ElementwiseOp` to `results`.
/// Currently this is only `FoldTranspose`.
void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
4326+
MLIRContext *context) {
4327+
results.add<FoldTranspose>(context);
4328+
}
4329+
42884330
//===----------------------------------------------------------------------===//
42894331
// PackOp/UnPackOp Common
42904332
//===----------------------------------------------------------------------===//
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
2+
3+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
4+
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
5+
//
6+
// CHECK: func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
7+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
8+
// CHECK-SAME: indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
9+
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
10+
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
11+
//
12+
// Input IR: a unary `exp` elementwise op whose only input is the result of a
// [1, 0, 2] linalg.transpose of %A. The expectations above require the
// elementwise op to read %A directly through a transposed input map.
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
13+
%empty = tensor.empty() : tensor<8x16x32xf32>
14+
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
15+
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
16+
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
17+
return %result : tensor<8x16x32xf32>
18+
}
19+
20+
// -----
21+
22+
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
23+
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
24+
//
25+
// CHECK: func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
26+
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
27+
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
28+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
29+
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
30+
//
31+
// Input IR: a binary `add` elementwise op where only the second input (%B) is
// produced by a [1, 0] linalg.transpose; %A is consumed directly. The
// expectations above require only the second input map to become transposed.
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
32+
%c0 = arith.constant 0 : index
33+
%c1 = arith.constant 1 : index
34+
// NOTE(review): the dynamic sizes below are queried from %A although the
// tensor being transposed is %B; presumably both are assumed to have
// matching shapes. Confirm.
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
35+
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
36+
37+
// Destination of the transpose, with the two sizes swapped.
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
38+
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
39+
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
40+
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
41+
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
42+
return %result : tensor<?x?xf32>
43+
}

0 commit comments

Comments
 (0)