Commit 5accf5d

Handle collapse into single element
Signed-off-by: Ian Wood <[email protected]>
1 parent 5ecce45 commit 5accf5d

File tree: 2 files changed (+22, -5 lines)

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (8 additions & 5 deletions)

@@ -33,6 +33,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
   if (sourceShape.size() <= targetShape.size())
     return std::nullopt;
+  if (targetShape.size() == 1)
+    return SmallVector<ReassociationIndices>{
+        llvm::to_vector(llvm::seq<int64_t>(0, sourceShape.size()))};
   unsigned sourceDim = 0;
   SmallVector<ReassociationIndices> reassociationMap;
   reassociationMap.reserve(targetShape.size());
@@ -315,11 +318,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
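A minimal sketch, not part of the commit, of what the new early return means for callers of getReassociationIndicesForCollapse. It assumes the snippet is compiled inside an MLIR/LLVM build tree; the shapes are illustrative and mirror the trailing group of the new test below (64 x dynamic collapsing into one dynamic dimension). With the early return, any query whose target shape has a single dimension now yields one reassociation group covering every source dimension, regardless of dynamic extents:

// Sketch only: exercises the rank-1 target path added by this commit.
// Assumes MLIR headers from an LLVM build tree; shapes are illustrative.
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

void collapseIntoSingleElementExample() {
  int64_t dyn = mlir::ShapedType::kDynamic;
  llvm::SmallVector<int64_t> sourceShape = {64, dyn};
  llvm::SmallVector<int64_t> targetShape = {dyn};

  // The target has rank 1, so the new early return produces a single
  // reassociation group {0, 1} without consulting the generic matching loop.
  auto groups =
      mlir::getReassociationIndicesForCollapse(sourceShape, targetShape);
  if (groups) {
    for (int64_t dim : groups->front())
      llvm::outs() << dim << ' '; // prints: 0 1
    llvm::outs() << '\n';
  }
}

Queries with more than one target dimension still go through the existing product-matching loop; only the fully collapsed case is short-circuited.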

mlir/test/Dialect/Tensor/canonicalize.mlir (14 additions & 0 deletions)

@@ -1191,6 +1191,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
 
 // -----
 
+func.func @compose_expand_of_collapse_dynamic_collapse(%arg0 : tensor<4x13x10x64x?xf16>, %arg1 : index) -> tensor<4x13x10x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x13x10x64x?xf16> into tensor<52x10x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 13, 10, %arg1] : tensor<52x10x?xf16> into tensor<4x13x10x?xf16>
+  return %expanded : tensor<4x13x10x?xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic_collapse
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x13x10x64x?xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
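To see why this test needs the change above: the expand rebuilds the first two collapse groups to their original extents (4x13 and 10), so they split back into single result dimensions, while the trailing group folds 64 x ? into one dynamic dimension and stays collapsed, which appears to reduce to exactly the rank-1 target query handled by the new early return. Below is a rough, self-contained sketch of that per-group reasoning in plain standard C++; composeGroups and kDynamic are illustrative names, not MLIR API, and this is not the dialect's actual folding code.

// Rough sketch of the per-group reasoning behind the expected reassociation.
// Plain standard C++; illustrative only, not the MLIR implementation.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for a dynamic extent ('?')

using Group = std::vector<int64_t>;

// For each collapse group, decide whether the expand_shape rebuilds the
// original extents (then the group splits back into single dimensions) or
// keeps the group as one result dimension (the rank-1 target case).
std::vector<Group>
composeGroups(const std::vector<int64_t> &sourceShape,
              const std::vector<Group> &collapseGroups,
              const std::vector<std::vector<int64_t>> &expandedExtents) {
  std::vector<Group> composed;
  for (std::size_t g = 0; g < collapseGroups.size(); ++g) {
    const Group &group = collapseGroups[g];
    const std::vector<int64_t> &extents = expandedExtents[g];
    bool reexpanded = extents.size() == group.size();
    for (std::size_t i = 0; reexpanded && i < group.size(); ++i)
      reexpanded = extents[i] == sourceShape[group[i]];
    if (reexpanded) {
      for (int64_t dim : group)
        composed.push_back({dim}); // split back: [0], [1], ...
    } else {
      composed.push_back(group);   // stays collapsed, e.g. [3, 4]
    }
  }
  return composed;
}

int main() {
  // Shapes and groups from @compose_expand_of_collapse_dynamic_collapse.
  std::vector<int64_t> source = {4, 13, 10, 64, kDynamic};
  std::vector<Group> collapse = {{0, 1}, {2}, {3, 4}};
  std::vector<std::vector<int64_t>> expanded = {{4, 13}, {10}, {kDynamic}};

  for (const Group &group : composeGroups(source, collapse, expanded)) {
    std::cout << '[';
    for (std::size_t i = 0; i < group.size(); ++i)
      std::cout << group[i] << (i + 1 < group.size() ? ", " : "");
    std::cout << "] ";
  }
  std::cout << '\n'; // prints: [0] [1] [2] [3, 4]
}

Running the sketch prints [0] [1] [2] [3, 4], matching the CHECK-SAME reassociation expected after canonicalization.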
