Commit 0ac4821
[mlir][linalg] unfold projected permutation. (llvm#114704)
Patterns to decompose the input operand(s) of a linalg.generic that has a `projected permutation` affine-map -- i.e. effectively a folded `transpose`, `broadcast`, or a mixture of the two -- into an explicit transpose and broadcast. This is useful for instance when trying to recognize named ops.

email: [email protected]
1 parent c236dbc commit 0ac4821

File tree

5 files changed: +326 -0 lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions

@@ -1820,6 +1820,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
 /// linalg.fill(%cst, tensor.extract_slice(%init)).
 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to make explicit broadcasts and transposes in the
+/// input operands of a genericOp.
+void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns);
+
 /// Patterns to apply `splitReduction` below.
 void populateSplitReductionPattern(
     RewritePatternSet &patterns,

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TilingInterfaceImpl.cpp
   Transforms.cpp
   TransposeConv2D.cpp
+  DecomposeGenericByUnfoldingPermutation.cpp
   Vectorization.cpp
   WinogradConv2D.cpp

mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp

Lines changed: 249 additions & 0 deletions
@@ -0,0 +1,249 @@
//===- DecomposeGenericByUnfoldingPermutation.cpp -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This pattern decomposes the input operand(s) of a linalg.generic that has
/// a `transpose`, `broadcast`, or a mixture of the two, into an explicit
/// transpose and broadcast. Having them folded into the linalg.generic is a
/// good optimization but sometimes we may want to unwrap, i.e., `unfold`
/// them as explicit transpose and broadcast. This rewrite pattern helps do
/// it for each input operand. This is useful for instance when trying to
/// recognize named ops.
///
/// The transpose, broadcast, or mixture of both, are expressed in the affine
/// map of the operand. Technically it is a `projected permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity   = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
///    { indexing_maps = [#projection, #identity, #identity],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///    ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
///    outs(%z : tensor<5x9x7x8x10xf32>) {
///      ^bb0(%in: f32, %in_1: f32, %out: f32):
///        %div = arith.divf %in, %in_1 : f32
///        linalg.yield %div : f32
///    } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, operand `%x`'s map is a projected permutation. This can
/// be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
///                ins(%x : tensor<7x8x9xf32>)
///                outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
///                   ins(%x_trans : tensor<9x7x8xf32>)
///                   outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
///          ins(%x_trans_bc, %y :
///              tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
///          outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
/// ```
///
/// Note that linalg.generic has been 'specialized' to linalg.div.
///
/// When unfolding, it is more optimal to transpose first and then do the
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced dimensions, as the
/// broadcast hasn't happened yet. Also, the broadcast dimensions in a
/// linalg.generic come from the other operands (those not broadcasted along
/// that particular dimension). We work this out by computing the
/// convex-polyhedron shape of the linalg.generic iteration space from the
/// shapes of all the operands, both inputs and outputs.
///
struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override;
};

/// For the given `map`, determine what dimensions are transposed and what
/// dimensions are broadcasted.
/// Returns:
///   `transpose-permutation`, `broadcast-dimensions` (empty if not needed)
///
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
  assert(map.isProjectedPermutation(false) && "not a projection");

  // As the map is a projection it likely operates on a smaller set of
  // dimensions as far as the transpose is concerned (rest are broadcast).
  int64_t minorSize = map.getNumResults();

  SmallVector<int64_t> minorResult;
  for (int64_t i = 0; i < minorSize; ++i) {
    auto expr = cast<AffineDimExpr>(map.getResults()[i]);
    minorResult.push_back(expr.getPosition());
  }

  // If dims are not monotonically increasing then a transpose is present.
  SmallVector<int64_t> sortedResMap(minorResult);
  std::sort(sortedResMap.begin(), sortedResMap.end());
  bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
                                  sortedResMap.begin(), sortedResMap.end());

  // Walk the sorted map result to determine which dimensions are broadcasted.
  SmallVector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
    if (j < minorSize && sortedResMap[j] == i) {
      j++;
      continue;
    }
    broadcast.push_back(i);
  }

  SmallVector<int64_t> permutation;
  if (hasTranspose) {
    // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
    // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
    // `x`'s access is both transposed and broadcast. But when specifying
    // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
    // specified as `affine_map<(d0, d1, d2) -> (d1, d2, d0)>` instead of
    // referring to d3, d4. Therefore, re-base the transpose dimensions so
    // that they start from d0.
    permutation.resize(minorSize);
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sortedResMap[i], i});

    // Re-map the dimensions.
    SmallVector<int64_t> remappedResult(minorSize);
    for (int64_t i = 0; i < minorSize; ++i)
      remappedResult[i] = minorMap[minorResult[i]];

    // Calculate the permutation for the transpose.
    for (int64_t i = 0; i < minorSize; ++i)
      permutation[remappedResult[i]] = i;
  }
  return {permutation, broadcast};
}
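
// Worked example for computeTransposeBroadcast, using the map from the doc
// comment above, (d0, d1, d2, d3, d4) -> (d2, d3, d1):
//   minorResult = [2, 3, 1], sortedResMap = [1, 2, 3]  -> hasTranspose
//   broadcast   = [0, 4]  (d0 and d4 never appear among the results)
//   minorMap    = {1 -> 0, 2 -> 1, 3 -> 2}, remappedResult = [1, 2, 0]
//   permutation = [2, 0, 1], matching `permutation = [2, 0, 1]` on the
//   linalg.transpose emitted in the tests below.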

LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
    GenericOp op, PatternRewriter &rewriter) const {
  if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
      op.isSingleYieldOp() || !op.isAllParallelLoops())
    return failure();

  // If the map of an operand is not a `projected permutation` then it
  // cannot be decomposed into a mere transpose and broadcast.
  // The requirement that all maps be `projected permutation` may be
  // over-restrictive, but since we need to determine the shape of the
  // iteration space as well, reject if any map violates the assumption.
  for (auto &opOperand : op->getOpOperands()) {
    auto map = op.getMatchingIndexingMap(&opOperand);
    if (!map.isProjectedPermutation(false))
      return failure();
  }

  // Decomposing linalg.generic involves creating `tensor.empty`, which can
  // have dynamic shapes, but then we would have to work out which operand
  // can supply that runtime value (tensor.dim). Leaving it as a future TODO.
  if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
        auto opType = cast<RankedTensorType>(oper.get().getType());
        return ShapedType::isDynamicShape(opType.getShape());
      }))
    return failure();

  auto outputShape = op.getStaticLoopRanges();

  auto loc = op.getLoc();
  bool isChanged = false;
  SmallVector<Value> newInitValues = op.getDpsInputs();
  SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

  // Walk over each input operand and unfold it if it is transposed,
  // broadcast, or a mix of the two, as expressed by the operand's affine map.
  for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
    auto &map = newMap[i];
    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
    auto elType = inputRTType.getElementType();

    // Nothing to do if the map is already an identity.
    if (map.isIdentity())
      continue;

    auto [permutation, broadcastedDims] = computeTransposeBroadcast(map);

    // Does it need a transpose?
    if (!permutation.empty()) {
      // linalg.transpose permutes the dimensions of the input using the
      // rule: dim(result, i) = dim(input, permutation[i]).
      SmallVector<int64_t> transposedShape(map.getNumResults());
      for (int64_t i = 0; i < map.getNumResults(); ++i)
        transposedShape[i] = inputRTType.getShape()[permutation[i]];

      Value emptyTensor =
          rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

      auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
                                                      emptyTensor, permutation);
      newInitValues[i] = transposeOp->getResult(0);
      isChanged = true;
    }

    // Does it require a broadcast?
    if (!broadcastedDims.empty()) {
      assert(!broadcastedDims.empty() && "should have non-empty broadcast dims");
      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
          loc, outputShape, inputRTType.getElementType());

      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
          loc, newInitValues[i], emptyTensor, broadcastedDims);

      newInitValues[i] = broadcastOp->getResult(0);
      isChanged = true;
    }
    newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
  }

  if (isChanged) {
    SmallVector<Value> operands = op->getOperands();
    ValueRange operandsRef(operands);

    auto newOp = rewriter.create<linalg::GenericOp>(
        /*location=*/op.getLoc(),
        /*resultTensorTypes=*/op->getResultTypes(),
        /*inputs=*/newInitValues,
        /*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
        /*indexingMaps=*/newMap,
        /*iteratorTypes=*/op.getIteratorTypesArray());

    newOp.getRegion().takeBody(op->getRegion(0));
    rewriter.replaceOp(op, newOp->getResults());
  }
  return success();
}

} // namespace

void mlir::linalg::populateDecomposeProjectedPermutationPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<DecomposeProjectedPermutation>(patterns.getContext());
}
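
The dimension bookkeeping above can be tried out without an MLIR build. Below is a minimal plain-C++ model of computeTransposeBroadcast, assuming the projected permutation is handed in as the loop-dimension count plus the list of result dimension positions; the function name and driver are illustrative only, not MLIR API:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>
#include <vector>

// Model of computeTransposeBroadcast: `numDims` is the number of loop
// dimensions of the generic op; `results` holds the dim positions produced
// by the operand's projected-permutation map, in map order.
std::pair<std::vector<int64_t>, std::vector<int64_t>>
transposeBroadcastModel(int64_t numDims, const std::vector<int64_t> &results) {
  int64_t minorSize = results.size();

  // A mismatch with the sorted order means the operand access is transposed.
  std::vector<int64_t> sorted(results);
  std::sort(sorted.begin(), sorted.end());
  bool hasTranspose = results != sorted;

  // Loop dims that never appear among the map results are broadcast dims.
  std::vector<int64_t> broadcast;
  for (int64_t i = 0, j = 0; i < numDims; ++i) {
    if (j < minorSize && sorted[j] == i) {
      ++j;
      continue;
    }
    broadcast.push_back(i);
  }

  // Re-base the surviving dims to 0..minorSize-1, then invert the re-based
  // result order to get the permutation linalg.transpose expects.
  std::vector<int64_t> permutation;
  if (hasTranspose) {
    std::map<int64_t, int64_t> minorMap;
    for (int64_t i = 0; i < minorSize; ++i)
      minorMap.insert({sorted[i], i});
    permutation.assign(minorSize, 0);
    for (int64_t i = 0; i < minorSize; ++i)
      permutation[minorMap[results[i]]] = i;
  }
  return {permutation, broadcast};
}

int main() {
  // (d0, d1, d2, d3, d4) -> (d2, d3, d1), as in the doc comment and tests.
  auto [perm, bcast] = transposeBroadcastModel(5, {2, 3, 1});
  for (int64_t p : perm)
    std::cout << p << ' '; // prints: 2 0 1
  std::cout << "| ";
  for (int64_t b : bcast)
    std::cout << b << ' '; // prints: 0 4
  std::cout << '\n';
}
```

Compiled as C++17, this reproduces the `permutation = [2, 0, 1]` and `dimensions = [0, 4]` seen in the `transpose_and_broadcast` test below.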

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 1 addition & 0 deletions

@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
 void LinalgSpecializeGenericOpsPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   populateLinalgGenericOpsSpecializationPatterns(patterns);
+  populateDecomposeProjectedPermutationPatterns(patterns);
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
     signalPassFailure();
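
The new entry point is not tied to the specialize pass; any pipeline can apply it through the usual greedy driver, exactly as above. A minimal downstream sketch, assuming a ModuleOp is at hand (the wrapper function name is illustrative, not part of this commit):

```cpp
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Unfold projected-permutation inputs of every linalg.generic in `module`.
static LogicalResult unfoldProjectedPermutations(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  linalg::populateDecomposeProjectedPermutationPatterns(patterns);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```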
Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
  %res = linalg.generic
     { indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
       %div = arith.divf %in, %in_1 : f32
       linalg.yield %div : f32
  } -> tensor<5x9x7x8x10xf32>
  return %res : tensor<5x9x7x8x10xf32>
}

// CHECK-LABEL: transpose_and_broadcast
// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %res = linalg.generic
     { indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
     outs(%z : tensor<2x16x32xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
       %div = arith.divf %in, %in_1 : f32
       linalg.yield %div : f32
  } -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: transpose_only
// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
  %res = linalg.generic
     { indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
     ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
     outs(%z : tensor<2x16x32xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
       %div = arith.divf %in, %in_1 : f32
       linalg.yield %div : f32
  } -> tensor<2x16x32xf32>
  return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: broadcast_only
// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic
