Skip to content

Commit e33dd8c

Browse files
committed
address reviewers' comments.
1 parent c8ec0f4 commit e33dd8c

File tree

7 files changed

+100
-44
lines changed

7 files changed

+100
-44
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,6 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
618618
let hasCustomAssemblyFormat = 1;
619619
let hasFolder = 1;
620620
let hasVerifier = 1;
621-
let hasCanonicalizer = 1;
622621

623622
let extraClassDeclaration = structuredOpsBaseDecls # [{
624623
/// Get the arity enum corresponding to the kind of op, e.g. if arg is

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
9999
let dependentDialects = ["linalg::LinalgDialect"];
100100
}
101101

102+
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
103+
let summary = "Fold transform, broadcast and other ops into elementwise";
104+
let dependentDialects = ["linalg::LinalgDialect"];
105+
}
106+
102107
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
103108
let summary = "Detensorize linalg ops";
104109
let dependentDialects = [];

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
17101710
void populateLinalgGenericOpsSpecializationPatterns(
17111711
RewritePatternSet &patterns);
17121712

1713+
/// Populates `patterns` with patterns that fold operations like
1714+
/// `linalg.transform` into elementwise op map.
1715+
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
1716+
17131717
/// Linalg decompose convolutions patterns
17141718

17151719
/// Populates patterns to decompose high-D convolution ops into low-D ones.

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include "mlir/Dialect/Arith/IR/Arith.h"
1818
#include "mlir/Dialect/Arith/Utils/Utils.h"
1919
#include "mlir/Dialect/Complex/IR/Complex.h"
20-
#include "mlir/Dialect/Linalg/Utils/Utils.h"
2120
#include "mlir/Dialect/Math/IR/Math.h"
2221
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2322
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -4286,47 +4285,6 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() {
42864285
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
42874286
}
42884287

4289-
namespace {
4290-
struct FoldTranspose : public OpRewritePattern<ElementwiseOp> {
4291-
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
4292-
4293-
LogicalResult matchAndRewrite(ElementwiseOp op,
4294-
PatternRewriter &rewriter) const override {
4295-
bool changed = false;
4296-
SmallVector<Value> newIns;
4297-
SmallVector<AffineMap> newMaps;
4298-
for (OpOperand *operand : op.getDpsInputOperands()) {
4299-
AffineMap map = op.getMatchingIndexingMap(operand);
4300-
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
4301-
4302-
if (!map.isIdentity() || !transposeOp) {
4303-
// push in original operand and its map.
4304-
newIns.push_back(operand->get());
4305-
newMaps.push_back(map);
4306-
continue;
4307-
}
4308-
newIns.push_back(transposeOp.getInput());
4309-
// push in transposeOp's inverse permutation map.
4310-
newMaps.push_back(transposeOp.getMatchingIndexingMap(
4311-
transposeOp.getDpsInputOperand(0)));
4312-
changed = true;
4313-
}
4314-
if (!changed)
4315-
return failure();
4316-
newMaps.push_back(op.getIndexingMapsArray().back());
4317-
4318-
rewriter.replaceOpWithNewOp<ElementwiseOp>(
4319-
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
4320-
rewriter.getAffineMapArrayAttr(newMaps));
4321-
return success();
4322-
}
4323-
};
4324-
} // namespace
4325-
void ElementwiseOp::getCanonicalizationPatterns(RewritePatternSet &results,
4326-
MLIRContext *context) {
4327-
results.add<FoldTranspose>(context);
4328-
}
4329-
43304288
//===----------------------------------------------------------------------===//
43314289
// PackOp/UnPackOp Common
43324290
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
1414
EliminateEmptyTensors.cpp
1515
EraseUnusedOperandsAndResults.cpp
1616
FoldAddIntoDest.cpp
17+
FoldIntoElementwise.cpp
1718
FusePadOpWithLinalgProducer.cpp
1819
Fusion.cpp
1920
Generalization.cpp
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements folding ops such as transpose and broadcast into the
10+
// affine maps of the elementwise op.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Dialect/Linalg/Passes.h"
16+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/ADT/TypeSwitch.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
24+
#include "mlir/Dialect/Linalg/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
using namespace mlir::linalg;
29+
30+
#define DEBUG_TYPE "linalg-fold-into-elementwise"
31+
32+
namespace {
33+
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
34+
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
35+
36+
LogicalResult matchAndRewrite(ElementwiseOp op,
37+
PatternRewriter &rewriter) const override {
38+
bool changed = false;
39+
SmallVector<Value> newIns;
40+
SmallVector<AffineMap> newMaps;
41+
for (OpOperand *operand : op.getDpsInputOperands()) {
42+
AffineMap map = op.getMatchingIndexingMap(operand);
43+
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
44+
45+
if (!map.isIdentity() || !transposeOp) {
46+
// push in original operand and its map.
47+
newIns.push_back(operand->get());
48+
newMaps.push_back(map);
49+
continue;
50+
}
51+
newIns.push_back(transposeOp.getInput());
52+
// push in transposeOp's inverse permutation map.
53+
newMaps.push_back(transposeOp.getMatchingIndexingMap(
54+
transposeOp.getDpsInputOperand(0)));
55+
changed = true;
56+
}
57+
if (!changed)
58+
return failure();
59+
newMaps.push_back(op.getIndexingMapsArray().back());
60+
61+
rewriter.replaceOpWithNewOp<ElementwiseOp>(
62+
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
63+
rewriter.getAffineMapArrayAttr(newMaps));
64+
return success();
65+
}
66+
};
67+
68+
struct LinalgFoldIntoElementwisePass
69+
: public impl::LinalgFoldIntoElementwisePassBase<
70+
LinalgFoldIntoElementwisePass> {
71+
using impl::LinalgFoldIntoElementwisePassBase<
72+
LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
73+
74+
void runOnOperation() override {
75+
llvm::outs() << "Hellow from fold into elemenwise \n";
76+
Operation *op = getOperation();
77+
RewritePatternSet patterns(op->getContext());
78+
populateLinalgFoldIntoElementwisePatterns(patterns);
79+
80+
if (failed(applyPatternsGreedily(op, std::move(patterns))))
81+
return signalPassFailure();
82+
}
83+
};
84+
} // namespace
85+
86+
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
87+
RewritePatternSet &patterns) {
88+
patterns.add<FoldTransposePattern>(patterns.getContext());
89+
}

mlir/test/Dialect/Linalg/elementwise/fold.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -linalg-fold-into-elementwise -split-input-file | FileCheck %s
22

33
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
44
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>

0 commit comments

Comments
 (0)