Skip to content

Commit 296f805

Browse files
committed
Revise based on 3rd round review comments
1 parent e3373b8 commit 296f805

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
9090
std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
9191
computeTransposeBroadcast(AffineMap &map) {
9292
assert(map.isProjectedPermutation(false) && "not a projection");
93+
94+
// As the map is a projection it likely operates on a smaller set of
95+
// dimensions as far as the transpose is concerned (rest are broadcast).
9396
int64_t minorSize = map.getNumResults();
9497

9598
SmallVector<int64_t> minorResult;
@@ -116,13 +119,13 @@ computeTransposeBroadcast(AffineMap &map) {
116119

117120
SmallVector<int64_t> permutation;
118121
if (hasTranspose) {
119-
/// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
120-
/// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
121-
/// `x`s access is both transposed and brodcast. But when specifying
122-
/// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
123-
/// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
124-
/// refering to d3, d4. Therefore, re-base the transpose dimensions so
125-
/// that they start from d0.
122+
// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
123+
// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
124+
// `x`s access is both transposed and broadcast. But when specifying
125+
// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
126+
// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
127+
// referring to d3, d4. Therefore, re-base the transpose dimensions so
128+
// that they start from d0.
126129
permutation.resize(minorSize);
127130
std::map<int64_t, int64_t> minorMap;
128131
for (int64_t i = 0; i < minorSize; ++i)
@@ -147,14 +150,19 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
147150
op.isSingleYieldOp() || !op.isAllParallelLoops())
148151
return failure();
149152

150-
// All maps need to be projected permutations.
153+
// If the map of an operand is not a `projected permutation` then
154+
// it cannot be decomposed to mere transpose and broadcast.
155+
// The requirement that all maps be `projected permutation` may be
156+
// over-restrictive but since we need to determine shape of the
157+
// iteration space as well, reject if any map violates the assumption.
151158
for (auto &opOperand : op->getOpOperands()) {
152159
auto map = op.getMatchingIndexingMap(&opOperand);
153160
if (!map.isProjectedPermutation(false))
154161
return failure();
155162
}
156163

157-
// Currently we handle only static shapes.
164+
// Decomposing linalg.generic involves creating `tensor.empty`
165+
// which cannot have dynamic shapes.
158166
for (auto &operand : op->getOpOperands()) {
159167
auto opType = cast<RankedTensorType>(operand.get().getType());
160168
for (auto size : opType.getShape())

mlir/test/Dialect/Linalg/decompose-generic-by-unfolding-projected-permutation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
77
%res = linalg.generic
8-
{ indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
8+
{ indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
99
ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
1010
^bb0(%in: f32, %in_1: f32, %out: f32):
1111
%div = arith.divf %in, %in_1 : f32

0 commit comments

Comments
 (0)