
Commit a38ba01

[mlir][linalg] raise generic to named ops.
Add support for specializing linalg.broadcast and linalg.transpose from generic. Also refactors the existing checks so they can be reused across specializations.
Parent: 581c015

6 files changed: +180 -27 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
Lines changed: 10 additions & 0 deletions

@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);

+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.broadcast`. Returns broadcast dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaBroadcastOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a
+/// `linalg.transpose`. Returns permuted dimensions if true.
+std::optional<SmallVector<int64_t>>
+isaTransposeOpInterface(GenericOp genericOp);
+
 /// Checks whether a given `genericOp` is semantically equivalent to a single
 /// linalg elementwise unary op, e.g. linalg.exp.
 /// A linalg.generic body could be a series of unary elementwise ops e.g.
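To make the new checks concrete, this is the shape of generic that isaTransposeOpInterface is intended to recognize — an illustrative sketch, not IR taken from this commit: all-parallel loops, one input and one init, a permutation input map, an identity output map, and a yield-only body.

#perm     = affine_map<(d0, d1) -> (d1, d0)>
#identity = affine_map<(d0, d1) -> (d0, d1)>

// Semantically a linalg.transpose with permutation = [1, 0].
%res = linalg.generic
    {indexing_maps = [#perm, #identity],
     iterator_types = ["parallel", "parallel"]}
    ins(%A : tensor<16x64xf32>) outs(%Out : tensor<64x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<64x16xf32>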

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Lines changed: 100 additions & 15 deletions
@@ -22,6 +22,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include <algorithm>
+#include <numeric>

 using namespace mlir;
 using namespace mlir::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
   return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
 }

+// Returns true if all loops of the linalgOp are parallel.
+static bool isAllParallel(LinalgOp op) {
+  return op.getNumParallelLoops() == op.getNumLoops();
+}
+
+// Returns true if and only if the linalgOp takes one input and one init.
+static bool isSingleInputOutput(LinalgOp op) {
+  return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
+}
+
+// Returns true if the genericOp body is just a yieldOp that yields the
+// input operand as its result.
+static bool isSingleYieldOp(GenericOp op) {
+  if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
+    return false;
+
+  Block *body = op.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOpInterface implementation
 //===----------------------------------------------------------------------===//

 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
-  // Structural.
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+  // Structural and operands.
+  if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
     return false;

-  // Operands and maps.
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return false;
   auto mapRange = linalgOp.getIndexingMapsArray();
   if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
       !mapRange.back().isIdentity()) {
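All of the raisings in this file share these structural preconditions, and in IR terms isSingleYieldOp accepts exactly a body of this shape (a minimal illustrative sketch):

// Single-operation body: yield the input block argument unchanged.
^bb0(%in: f32, %out: f32):
  linalg.yield %in : f32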
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
 //===----------------------------------------------------------------------===//
 std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   // Structural.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;

   // Input should be referenced and init should not.
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
   OpOperand *value = genericOp.getDpsInputOperand(0);
   if (!genericOp.isScalar(value))
     return std::nullopt;
+  return value->get();
+}

-  Block *body = genericOp.getBody();
-  if (body->getOperations().size() != 1)
+//===----------------------------------------------------------------------===//
+// BroadcastOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaBroadcastOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
     return std::nullopt;

-  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
-  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
-      yieldOp->getOperand(0) != body->getArgument(0))
+  auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
+  auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
+  if (!isa<MemRefType, RankedTensorType>(t0) ||
+      !isa<MemRefType, RankedTensorType>(t1))
     return std::nullopt;
-  return value->get();
+
+  // Check that the output map is the identity. An injective map could also
+  // be a permutation of indices and expressible in linalg.generic, but it
+  // is not expressible with the named broadcast op.
+  auto dstMap = genericOp.getIndexingMapsArray()[1];
+  if (!dstMap.isIdentity())
+    return std::nullopt;
+
+  SmallVector<int64_t> position;
+  auto srcMap = genericOp.getIndexingMapsArray()[0];
+
+  // Check that the input map is monotonically increasing DimIds.
+  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
+    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
+    if (!expr)
+      return std::nullopt;
+    int64_t pos = expr.getPosition();
+    if (i > 0 && pos <= position[i - 1])
+      return std::nullopt;
+    position.push_back(expr.getPosition());
+  }
+
+  SmallVector<int64_t> broadcastedDims;
+  auto numDims = srcMap.getNumDims();
+  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
+    if (!llvm::is_contained(position, dim))
+      broadcastedDims.push_back(dim);
+  }
+  return broadcastedDims;
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOpInterface implementation
+//===----------------------------------------------------------------------===//
+std::optional<SmallVector<int64_t>>
+linalg::isaTransposeOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
+      !isSingleYieldOp(genericOp))
+    return std::nullopt;
+
+  // Mapping checks.
+  auto mapRange = genericOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
+      !mapRange.front().isPermutation())
+    return std::nullopt;
+
+  SmallVector<int64_t> permutation;
+  auto map = mapRange.front();
+  for (unsigned i = 0; i < map.getNumResults(); ++i) {
+    auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
+    permutation.push_back(expr.getPosition());
+  }
+  return permutation;
 }

 //===----------------------------------------------------------------------===//
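A worked example of the dimension computations above, under assumed maps (illustrative, not from this commit):

// Broadcast: srcMap = affine_map<(d0, d1, d2) -> (d0, d2)> gives
// position = [0, 2]; dstMap is the rank-3 identity, so it passes the
// check; numDims = 3. The only dim absent from position is 1, so the
// function returns broadcastedDims = [1], i.e.
// linalg.broadcast ... dimensions = [1].
//
// Transpose: srcMap = affine_map<(d0, d1) -> (d1, d0)> is a permutation;
// reading off the result positions yields permutation = [1, 0].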
@@ -106,8 +192,7 @@ static bool
 isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
                                           unsigned arity) {
   // Check all loops are parallel.
-  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
-      genericOp.getNumLoops() < 1)
+  if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
     return false;

   // Check there are arity-inputs, 1-output and all are identity-maps.

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
Lines changed: 27 additions & 0 deletions
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
 //===----------------------------------------------------------------------===//
 FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                       GenericOp genericOp) {
+  // Copy
   if (isaCopyOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }

+  // Fill
   if (isaFillOpInterface(genericOp)) {
     LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }

+  // Broadcast
+  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
+      isaBroadcastOpInterface(genericOp);
+  if (equivalentToBroadcast) {
+    auto dims = *equivalentToBroadcast;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        dims);
+    return namedOp;
+  }
+
+  // Transpose
+  std::optional<SmallVector<int64_t>> equivalentToTranspose =
+      isaTransposeOpInterface(genericOp);
+  if (equivalentToTranspose) {
+    auto permutation = *equivalentToTranspose;
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
+        permutation);
+    return namedOp;
+  }
+
+  // Elementwise Unary
   if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
     Operation *op = &genericOp.getBody()->front();
     if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }

+  // Elementwise Binary
   if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
     bool swap = areBinOpsSwapped(genericOp);
     Operation *op = &genericOp.getBody()->front();
@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
     }
   }

+  // Contraction - e.g. matmul
   if (isaContractionOpInterface(genericOp)) {
     return specializeLinalgContractions(rewriter, genericOp);
   }
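Putting the new dispatch together, a broadcast-shaped generic is rewritten into the named op. A sketch of the expected effect of --linalg-specialize-generic-ops on such IR (the input below is illustrative; dimensions = [1] follows from the maps, since d1 appears only in the output map):

#src = affine_map<(d0, d1, d2) -> (d0, d2)>
#dst = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// Before specialization:
%0 = linalg.generic
    {indexing_maps = [#src, #dst],
     iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%A : tensor<3x5xf32>) outs(%Out : tensor<3x4x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<3x4x5xf32>

// After specialization:
%broadcasted = linalg.broadcast ins(%A : tensor<3x5xf32>)
                   outs(%Out : tensor<3x4x5xf32>) dimensions = [1]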
New test file (path not captured)
Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: broadcast_first_dimension
+// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
+//
+func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
+  return %res : tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: broadcast_mid_dimension
+// CHECK-SAME: %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
+//
+func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
+  %res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
+  return %res : tensor<3x4x5xf32>
+}
+
+// CHECK-LABEL: broadcast_multiple_dimensions
+// CHECK-SAME: %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
+//
+func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
+  %res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
+  return %res : tensor<3x4x5x6x7x8x9xf32>
+}
New test file (path not captured)
Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s
+
+// CHECK-LABEL: linalg_transpose
+// CHECK-SAME: %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
+//
+func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
+  %res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
+  return %res : tensor<64x16xf32>
+}

mlir/test/Dialect/Linalg/transform-op-specialize.mlir
Lines changed: 0 additions & 12 deletions

This negative test is removed because a broadcast-shaped generic now specializes to linalg.broadcast instead of failing to match.

@@ -4,18 +4,6 @@
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>

-func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
-  // expected-note @below {{when applied to this op}}
-  linalg.generic {
-    indexing_maps = [#map1, #map],
-    iterator_types = ["parallel", "parallel"]}
-    ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
-  ^bb0(%in: f32, %out: f32):
-    linalg.yield %in : f32
-  }
-  return
-}
-
 func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {
