Skip to content

Commit 510f3c7

Browse files
committed
[mlir] Fix ComposeExpandOfCollapseOp for dynamic case
Signed-off-by: Ian Wood <[email protected]>
1 parent a903271 commit 510f3c7

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,13 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
387387
auto resultSubShape =
388388
resultShape.slice(resultIndices.front(), resultIndices.size());
389389

390+
if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2)
391+
return std::nullopt;
392+
390393
if (srcSubShape.size() == resultSubShape.size()) {
391-
if (srcSubShape != resultSubShape ||
392-
llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
394+
if (srcSubShape != resultSubShape)
393395
return std::nullopt;
394-
}
396+
395397
for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
396398
composedReassociation.emplace_back(1, srcIndices.front() + index);
397399
}

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
13191319

13201320
// -----
13211321

1322+
func.func @no_compose_collapse_of_expand_dynamic(%arg0 : tensor<?x8x128x?xf16>, %arg1: index) -> tensor<?x128x?xf16> {
1323+
%collapse = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<?x8x128x?xf16> into tensor<?xf16>
1324+
%expanded_19 = tensor.expand_shape %collapse [[0, 1, 2]] output_shape [%arg1, 8, %arg1] : tensor<?xf16> into tensor<?x128x?xf16>
1325+
return %expanded_19 : tensor<?x128x?xf16>
1326+
}
1327+
// CHECK-LABEL: func @no_compose_collapse_of_expand_dynamic
1328+
// CHECK-SAME: %[[ARG0:.+]]: tensor
1329+
// CHECK-SAME: %[[ARG1:.+]]: index
1330+
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
1331+
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
1332+
// CHECK: return %[[EXPAND]]
1333+
1334+
// -----
1335+
13221336
// CHECK-LABEL: func @zero_rank_reshape_multi
13231337
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
13241338
// CHECK: return %arg0

0 commit comments

Comments
 (0)