diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 0336423c57b1d..169f28cece4dc 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -33,6 +33,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> targetShape) { if (sourceShape.size() <= targetShape.size()) return std::nullopt; + if (targetShape.size() == 1) + return SmallVector<ReassociationIndices>{ + llvm::to_vector(llvm::seq<int64_t>(0, sourceShape.size()))}; unsigned sourceDim = 0; SmallVector<ReassociationIndices> reassociationMap; reassociationMap.reserve(targetShape.size()); @@ -315,11 +318,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams( // have proven that these are not sliced. In this case we just take // the full extent of each dimension in the reassociation list. if (linearizedDimensions[it.index()]) { - llvm::append_range( - offsetsSizesAndStrides, - llvm::map_range(it.value(), [&](int64_t idx) -> Range { - return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; - })); + llvm::append_range(offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], + oneAttr}; + })); continue; } diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 90cc0ca658ffb..bbbef2ebc9d2b 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1191,6 +1191,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, % // ----- +func.func @compose_expand_of_collapse_dynamic_collapse(%arg0 : tensor<4x13x10x64x?xf16>, %arg1 : index) -> tensor<4x13x10x?xf16> { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x13x10x64x?xf16> into tensor<52x10x?xf16> + %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 13, 10, %arg1] : tensor<52x10x?xf16> into tensor<4x13x10x?xf16> + return %expanded : 
tensor<4x13x10x?xf16> } + +// CHECK-LABEL: func @compose_expand_of_collapse_dynamic_collapse +// CHECK-SAME: %[[ARG0:.+]]: tensor<4x13x10x64x?xf16> +// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]] +// CHECK-SAME: [0], [1], [2], [3, 4] +// CHECK: return %[[RESULT]] + +// ----- + // CHECK-LABEL: func @zero_rank_reshape_multi func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> { // CHECK: return %arg0