Skip to content

Commit 774c9c6

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Add canonicalization of linalg op -> dim op.
Add canonicalization to replace use of the result of a linalg operation on tensors in a dim operation, to use one of the operands of the linalg operations instead. This allows the linalg op itself to be deleted when all its non-dim uses are removed (say through tiling, etc.) Differential Revision: https://reviews.llvm.org/D93076
1 parent 547b032 commit 774c9c6

File tree

7 files changed

+344
-47
lines changed

7 files changed

+344
-47
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def Linalg_Dialect : Dialect {
3232
the op semantics.
3333
}];
3434
let cppNamespace = "::mlir::linalg";
35+
let dependentDialects = [
36+
"AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
37+
];
3538
}
3639

3740
// Whether a type is a RangeType.

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,56 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
946946
return inversePermutation(getLoopsToShapesMap());
947947
}]
948948
>,
949+
InterfaceMethod<
950+
/*desc=*/[{
951+
Return the position in the results of the affine map computed
952+
by getLoopsToShapesMap() that represents the shape of an
953+
operand (input or output) at a dimension.
954+
}],
955+
/*retTy=*/"Optional<unsigned>",
956+
/*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
957+
/*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
958+
/*methodBody=*/"",
959+
/*defaultImplementation=*/[{
960+
unsigned pos = 0;
961+
for (auto type : llvm::enumerate(getShapedOperandTypes())) {
962+
if (type.index() == operandIdx) return pos + dim;
963+
pos += type.value().getRank();
964+
}
965+
return {};
966+
}]
967+
>,
968+
InterfaceMethod<
969+
/*desc=*/[{
970+
Return the position in the results of the affine map computed
971+
by getLoopsToShapesMap() that represents the shape of an
972+
input operand at a dimension.
973+
}],
974+
/*retTy=*/"Optional<unsigned>",
975+
/*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
976+
/*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
977+
/*methodBody=*/"",
978+
/*defaultImplementation=*/[{
979+
if (inputIdx >= getNumInputs()) return {};
980+
return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
981+
}]
982+
>,
983+
InterfaceMethod<
984+
/*desc=*/[{
985+
Return the position in the results of the affine map computed
986+
by getLoopsToShapesMap() that represents the shape of the
987+
result value at a dimension.
988+
}],
989+
/*retTy=*/"Optional<unsigned>",
990+
/*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
991+
/*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
992+
/*methodBody=*/"",
993+
/*defaultImplementation=*/[{
994+
if (resultIdx >= getNumOutputs()) return {};
995+
return getOperandDimPositionInLoopsToShapeMap(
996+
getNumInputs() + resultIdx, dim);
997+
}]
998+
>,
949999

9501000
//===------------------------------------------------------------------===//
9511001
// Other static interface methods.
@@ -1027,6 +1077,12 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
10271077
}
10281078
return res;
10291079
}
1080+
1081+
/// Returns the value that expresses the shape of the output in terms of
1082+
/// shape of the input operands where possible
1083+
Optional<Value> inferResultDimFromInputShapes
1084+
(OpBuilder &b, Location loc, unsigned resultIdx, unsigned im);
1085+
10301086
//========================================================================//
10311087
// Helper functions to mutate the `operand_segment_sizes` attribute.
10321088
// These are useful when cloning and changing operand types.

mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
1010
#define MLIR_DIALECT_LINALG_LINALGTYPES_H_
1111

12+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
13+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
14+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1215
#include "mlir/IR/Dialect.h"
1316
#include "mlir/IR/Types.h"
1417

mlir/include/mlir/IR/AffineExprVisitor.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -159,29 +159,29 @@ template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
159159

160160
// Default visit methods. Note that the default op-specific binary op visit
161161
// methods call the general visitAffineBinaryOpExpr visit method.
162-
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
163-
void visitAddExpr(AffineBinaryOpExpr expr) {
164-
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
162+
RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
163+
RetTy visitAddExpr(AffineBinaryOpExpr expr) {
164+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
165165
}
166-
void visitMulExpr(AffineBinaryOpExpr expr) {
167-
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
166+
RetTy visitMulExpr(AffineBinaryOpExpr expr) {
167+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
168168
}
169-
void visitModExpr(AffineBinaryOpExpr expr) {
170-
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
169+
RetTy visitModExpr(AffineBinaryOpExpr expr) {
170+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
171171
}
172-
void visitFloorDivExpr(AffineBinaryOpExpr expr) {
173-
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
172+
RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
173+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
174174
}
175-
void visitCeilDivExpr(AffineBinaryOpExpr expr) {
176-
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
175+
RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
176+
return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
177177
}
178-
void visitConstantExpr(AffineConstantExpr expr) {}
179-
void visitDimExpr(AffineDimExpr expr) {}
180-
void visitSymbolExpr(AffineSymbolExpr expr) {}
178+
RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
179+
RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
180+
RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
181181

182182
private:
183183
// Walk the operands - each operand is itself walked in post order.
184-
void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
184+
RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
185185
walkPostOrder(expr.getLHS());
186186
walkPostOrder(expr.getRHS());
187187
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 122 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
1717
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
1818
#include "mlir/Dialect/StandardOps/IR/Ops.h"
19+
#include "mlir/IR/AffineExprVisitor.h"
1920
#include "mlir/IR/Matchers.h"
2021
#include "mlir/IR/OpImplementation.h"
2122
#include "mlir/IR/PatternMatch.h"
2223

2324
#include "llvm/ADT/DenseMap.h"
2425
#include "llvm/ADT/SetVector.h"
26+
#include "llvm/ADT/SmallSet.h"
2527
#include "llvm/ADT/StringSet.h"
2628
#include "llvm/Support/FormatVariadic.h"
2729
#include "llvm/Support/MathExtras.h"
@@ -86,6 +88,82 @@ SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
8688
return res;
8789
}
8890

91+
/// Visitor to check if any of the given set of positions from AffineDimExprs
92+
/// are used within an AffineExpr.
93+
struct HasAffineDimExprVisitor
94+
: public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
95+
HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
96+
: positions(positions) {}
97+
98+
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
99+
return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
100+
}
101+
102+
bool visitDimExpr(AffineDimExpr dimExpr) {
103+
return positions.count(dimExpr.getPosition());
104+
}
105+
106+
bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
107+
108+
bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
109+
110+
private:
111+
llvm::SmallSet<unsigned, 4> positions;
112+
};
113+
114+
Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
115+
Location loc,
116+
unsigned resultIdx,
117+
unsigned dim) {
118+
// An example that helps understand the logic below.
119+
// Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
120+
// We want to express the shape of dim 0 of O in terms of shape of the inputs.
121+
// This is achieved as follows.
122+
// loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
123+
// subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
124+
// shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
125+
// resultFromFromInputDim = subMapOfResultDim.compose(shapesToLoopMap)
126+
// = (d0, d1, d2, d3, d4, d5) -> (d0 + d1)
127+
AffineMap loopsToShapesMap = getLoopsToShapesMap();
128+
129+
// Find the position in the above map that represents the shape of the
130+
// result:dim being inferred.
131+
Optional<unsigned> resultDimSubMapPos =
132+
getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
133+
if (!resultDimSubMapPos)
134+
return {};
135+
136+
/// From loopsToShapesMap extract the submap that represents the shape of the
137+
/// (resultIdx, dim) needed
138+
AffineMap loopToResultDimShapeMap =
139+
loopsToShapesMap.getSubMap(*resultDimSubMapPos);
140+
AffineMap operandShapesToResultDimMap =
141+
loopToResultDimShapeMap.compose(getShapesToLoopsMap());
142+
143+
// Check that the result dim map does not contain the positions corresponding
144+
// to the outputs.
145+
llvm::SmallSet<unsigned, 4> outputDims;
146+
unsigned outputDimPosStart =
147+
getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
148+
unsigned outputDimPosEnd =
149+
getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1,
150+
getOutputOpOperands()
151+
.back()
152+
.get()
153+
.getType()
154+
.cast<ShapedType>()
155+
.getRank() -
156+
1)
157+
.getValue();
158+
llvm::for_each(llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd),
159+
[&outputDims](unsigned dim) { outputDims.insert(dim); });
160+
HasAffineDimExprVisitor checkDimExpr(outputDims);
161+
if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
162+
return llvm::None;
163+
return applyMapToValues(b, loc, operandShapesToResultDimMap,
164+
createFlatListOfOperandDims(b, loc))[0];
165+
}
166+
89167
/// Forward declarations.
90168
template <typename NamedStructuredOpType>
91169
static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@@ -2022,6 +2100,49 @@ struct FoldTensorCastOp : public RewritePattern {
20222100
return success();
20232101
}
20242102
};
2103+
2104+
/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
2105+
/// with std.dim operations that use one of the arguments. For example,
2106+
///
2107+
/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
2108+
/// %1 = dim %0, %c0
2109+
///
2110+
/// with
2111+
///
2112+
/// %1 = dim %arg0, %c0
2113+
///
2114+
/// where possible. With this the result of the `linalg.matmul` is not used in
2115+
/// dim operations. If the value produced is replaced with another value (say by
2116+
/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of
2117+
/// used in a dim op that would prevent the DCE of this op.
2118+
struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
2119+
using OpRewritePattern<DimOp>::OpRewritePattern;
2120+
2121+
LogicalResult matchAndRewrite(DimOp dimOp,
2122+
PatternRewriter &rewriter) const override {
2123+
Value dimValue = dimOp.memrefOrTensor();
2124+
Optional<int64_t> dimIndex = dimOp.getConstantIndex();
2125+
if (!dimIndex)
2126+
return failure();
2127+
auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
2128+
if (!linalgOp)
2129+
return failure();
2130+
2131+
unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
2132+
Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
2133+
rewriter, dimOp.getLoc(), resultIndex,
2134+
static_cast<unsigned>(*dimIndex));
2135+
if (!operandDimValue) {
2136+
// Its always possible to replace using the corresponding `outs`
2137+
// parameter.
2138+
operandDimValue = rewriter.create<DimOp>(
2139+
dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
2140+
}
2141+
rewriter.replaceOp(dimOp, *operandDimValue);
2142+
return success();
2143+
}
2144+
};
2145+
20252146
} // namespace
20262147

20272148
namespace {
@@ -2166,34 +2287,14 @@ struct RemoveIdentityLinalgOps : public RewritePattern {
21662287
return success();
21672288
}
21682289
};
2169-
2170-
/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
2171-
/// with the corresponding output tensor argument of the linalg op.
2172-
struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
2173-
using OpRewritePattern<DimOp>::OpRewritePattern;
2174-
2175-
LogicalResult matchAndRewrite(DimOp dimOp,
2176-
PatternRewriter &rewriter) const override {
2177-
Value dimOpArg = dimOp.memrefOrTensor();
2178-
auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
2179-
if (!linalgOp)
2180-
return failure();
2181-
2182-
auto results = linalgOp.getOperation()->getResults();
2183-
int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
2184-
auto outputTensors = linalgOp.getOutputTensors();
2185-
rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
2186-
return success();
2187-
}
2188-
};
21892290
} // namespace
21902291

21912292
#define CANONICALIZERS_AND_FOLDERS(XXX) \
21922293
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
21932294
MLIRContext *context) { \
21942295
results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp, \
21952296
RemoveIdentityLinalgOps>(); \
2196-
results.insert<ReplaceDimOfLinalgResult>(context); \
2297+
results.insert<ReplaceDimOfLinalgOpResult>(context); \
21972298
} \
21982299
\
21992300
LogicalResult XXX::fold(ArrayRef<Attribute>, \

mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,6 @@ struct LinalgInlinerInterface : public DialectInlinerInterface {
5858
//===----------------------------------------------------------------------===//
5959

6060
void mlir::linalg::LinalgDialect::initialize() {
61-
getContext()->getOrLoadDialect("std");
62-
getContext()->getOrLoadDialect("tensor");
63-
6461
addTypes<RangeType>();
6562
addOperations<
6663
#define GET_OP_LIST

0 commit comments

Comments
 (0)