Skip to content

Commit 697e7ea

Browse files
authored
fix(linag.pack): decompisition of OuterUnitDims Pattern (#16)
Given the following example: ``` module { func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> { %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32> return %pack : tensor<1x1x1x4x1xf32> } } ``` We would generate an invalid transpose operation because the calculated permutation would be `[0, 2, 0]` which is semantically incorrect. As the permutation must contain unique integers corresponding to the source tensor. The assumption that because all outer dimensions are 1 the transposition does not matter does not hold because the indices in `permutation` must be unique and match the source tensor indices. The following change attempts to amend that.
1 parent 9c357ad commit 697e7ea

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,16 +1192,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11921192
// %init = tensor.empty()
11931193
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11941194
// outs(%init)
1195-
// Two assumptions are made:
1196-
// 1. All outer dims are 1 - the corresponding transposition doesn't matter.
1197-
// 2. Inner dims position correspond to the trailing `numTiles` dims.
1198-
SmallVector<int64_t> tilesPermNormalized =
1199-
getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
1195+
// Assumptions made:
1196+
// 1. Inner dims position correspond to the trailing `numTiles` dims.
12001197
SmallVector<int64_t> srcPermForTranspose;
1201-
for (int64_t i = 0; i < (srcRank - numTiles); i++)
1198+
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
1199+
for (int64_t i = 0; i < srcRank; i++) {
1200+
// As we assume the trailing dimensions of the inner dim position correspond
1201+
// to the trailing indices of the transpose permutation, we need to
1202+
// calculate the remaining indicies of the transpose permutation. This is
1203+
// done by adding the indices not contained in the inner dimension position.
1204+
// For example if we have a source tensor of dimensions [0, 1, 2, 3]
1205+
// and inner dim position of [3, 0], the remaining indices are [1, 2].
1206+
// and the transpose will be [1, 2, 3, 0].
1207+
if (llvm::is_contained(innerDimPos, i))
1208+
continue;
12021209
srcPermForTranspose.push_back(i);
1203-
1204-
srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
1210+
}
1211+
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
12051212

12061213
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
12071214
<< "perm: " << llvm::interleaved(srcPermForTranspose)

mlir/test/Dialect/Linalg/decompose-pack.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
229229
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
230230
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
231231
// CHECK: return %[[INSERT]]
232+
233+
// -----
234+
235+
func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
236+
%pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
237+
return %pack : tensor<1x1x1x4x1xf32>
238+
}
239+
240+
// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
241+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
242+
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
243+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
244+
// CHECK: %[[TRANSP:.+]] = linalg.transpose
245+
// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
246+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
247+
// CHECK-SAME: permutation = [1, 2, 0]
248+
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
249+
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
250+
// CHECK: return %[[INSERT]]

0 commit comments

Comments
 (0)