Skip to content

Commit d6f0385

Browse files
committed
Compute reassociation in dynamic cases.
Signed-off-by: Ian Wood <[email protected]>
1 parent b5b8a59 commit d6f0385

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,27 +49,26 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
4949
int64_t currTargetShape = targetShape[targetDim];
5050
while (sourceDim < (sourceShape.size() - 1) &&
5151
sourceShape[sourceDim] != ShapedType::kDynamic &&
52-
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
52+
(currTargetShape == ShapedType::kDynamic ||
53+
prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape)) {
5354
prodOfCollapsedDims *= sourceShape[sourceDim];
5455
currIndices.push_back(sourceDim++);
5556
}
5657

58+
if (sourceDim >= sourceShape.size())
59+
return std::nullopt;
60+
5761
// If the current expanded dimension is dynamic, then the collapsed
5862
// dimensions should also be dynamic and the product of all previously unprocessed
5963
// dimensions of the expanded shape should be 1.
6064
if (sourceShape[sourceDim] == ShapedType::kDynamic &&
61-
(currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
62-
return std::nullopt;
63-
64-
// If the collapsed dim is dynamic, the current expanded dim should also
65-
// be dynamic.
66-
if (currTargetShape == ShapedType::kDynamic &&
67-
sourceShape[sourceDim] != ShapedType::kDynamic)
65+
currTargetShape != ShapedType::kDynamic)
6866
return std::nullopt;
6967

7068
// For static shapes, the product of dimensions of the expanded shape
7169
// should match the collapsed dimension shape.
72-
if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
70+
if (sourceShape[sourceDim] != ShapedType::kDynamic &&
71+
prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
7372
return std::nullopt;
7473

7574
currIndices.push_back(sourceDim++);
@@ -315,11 +314,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
315314
// have proven that these are not sliced. In this case we just take
316315
// the full extent of each dimension in the reassociation list.
317316
if (linearizedDimensions[it.index()]) {
318-
llvm::append_range(
319-
offsetsSizesAndStrides,
320-
llvm::map_range(it.value(), [&](int64_t idx) -> Range {
321-
return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
322-
}));
317+
llvm::append_range(offsetsSizesAndStrides,
318+
llvm::map_range(it.value(), [&](int64_t idx) -> Range {
319+
return {zeroAttr, collapseShapeInputShape[idx],
320+
oneAttr};
321+
}));
323322
continue;
324323
}
325324

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,6 +1191,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
11911191

11921192
// -----
11931193

1194+
func.func @compose_expand_of_collapse_dynamic_collapse(%arg0 : tensor<4x13x10x64x?xf16>, %arg1 : index) -> tensor<4x13x10x?xf16> {
1195+
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x13x10x64x?xf16> into tensor<52x10x?xf16>
1196+
%expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 13, 10, %arg1] : tensor<52x10x?xf16> into tensor<4x13x10x?xf16>
1197+
return %expanded : tensor<4x13x10x?xf16>
1198+
}
1199+
1200+
// CHECK-LABEL: func @compose_expand_of_collapse_dynamic_collapse
1201+
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x13x10x64x?xf16>
1202+
// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
1203+
// CHECK-SAME: [0], [1], [2], [3, 4]
1204+
// CHECK: return %[[RESULT]]
1205+
1206+
// -----
1207+
11941208
// CHECK-LABEL: func @zero_rank_reshape_multi
11951209
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
11961210
// CHECK: return %arg0

0 commit comments

Comments
 (0)