Merged
Commits (24)
2ae4480  [mlir][tensor] Loosen restrictions on folding dynamic reshapes (AGindinson, Apr 28, 2025)
4f7c389  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 5, 2025)
18da6fe  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 9, 2025)
52ff4e0  [fixup] Algorithm rewrite (AGindinson, May 9, 2025)
1c85a68  [fixup] Add/expand unit tests (AGindinson, May 9, 2025)
0fe986e  [fixup] variable renaming (AGindinson, May 9, 2025)
e3aa239  [fixup] Additional edge-case (AGindinson, May 9, 2025)
16a932c  [WIP] Current tests pass (AGindinson, May 20, 2025)
dd36c47  [WIP] New tests (AGindinson, May 20, 2025)
114af4b  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 20, 2025)
07ed33d  [fixup] Add scalar target tests & fix em (AGindinson, May 20, 2025)
6e61a52  [fixup] for self-induced unit dims problem (AGindinson, May 21, 2025)
b0e5c93  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 22, 2025)
a6a18d6  [fixup] apply non-functional comments (AGindinson, May 23, 2025)
ce007de  [fixup] apply greedy logic suggestions (AGindinson, May 23, 2025)
66adf99  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 23, 2025)
15caa29  [fixup] improve `getNonOverlappingIndicesWith(&rhs)` (AGindinson, May 28, 2025)
6b5d5cd  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 28, 2025)
35cb397  Merge branch 'main' into reassoc-expand-of-collapse (AGindinson, Jun 2, 2025)
20e9a9f  Merge branch 'main' into reassoc-expand-of-collapse (AGindinson, Jun 3, 2025)
880b394  [fixup] Reduce auto usage, drop obsolete variable (AGindinson, Jun 3, 2025)
cc6df04  [fixup] Move a comment to the right place (AGindinson, Jun 3, 2025)
ea9161d  [fixup] Clarify some early-return cases (AGindinson, Jun 3, 2025)
54abd87  [fixup] Improve auto usage further (AGindinson, Jun 3, 2025)
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (57 additions, 46 deletions)
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
 
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
-
+  unsigned sourceDim = 0, targetDim = 0;
+  for (; targetDim < numTargetDims; ++targetDim) {
     int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
-    }
-
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
-      return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    ReassociationIndices currIndices;
+    // 1. Target dimension is dynamic. Source shape should contain at least
+    // one dynamic dimension.
+    if (currTargetShape == ShapedType::kDynamic) {
+      // FIXME: We stop the search with the first dynamic dimension, while in
+      // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+      // indeterministic altogether when we have neighboring dynamic dimensions
+      // in the target shape. Most of these patterns will be safely rejected,
+      // however we might achieve more correct folds by taking affine
+      // expressions into account, if these can be passed on by the call sites.
+      bool foundDynamic = false;
+      while (sourceDim < numSourceDims) {
+        currIndices.push_back(sourceDim);
+        if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+          foundDynamic = true;
+          break;
+        }
+      }
+      if (!foundDynamic)
+        return std::nullopt;
+
+      reassociationMap.push_back(currIndices);
+      continue;
+    }
+    // 2. Target dimension is static. The product of dimensions of the expanded
+    // shape should match the collapsed dimension shape.
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    while (sourceDim < numSourceDims) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDim] == ShapedType::kDynamic)
+        return std::nullopt;
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      if (prodOfCollapsedDims > currTargetShape)
+        break;
+      else if (prodOfCollapsedDims == currTargetShape) {
+        currIndices.push_back(sourceDim++);
+        reachedTargetDimSize = true;
+        break;
+      } else // prodOfCollapsedDims < currTargetShape
+        currIndices.push_back(sourceDim++);
+    }
+    if (!reachedTargetDimSize)
+      return std::nullopt;
+
+    reassociationMap.push_back(currIndices);
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
-    return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+
+  // Now that we've mapped all the target dimensions, process any remaining
+  // entries in the source shape explicitly. Either the last target dimension
+  // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+  // applies when target shape is empty (can be the case for subshape
+  // reassociations).
+  for (; sourceDim < numSourceDims; sourceDim++) {
+    if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+        sourceShape[sourceDim] != ShapedType::kDynamic &&
         sourceShape[sourceDim] != 1)
       return std::nullopt;
     // The map is empty when the target type is a scalar.
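To make the rewritten control flow easier to follow, here is a minimal standalone sketch of the same matching logic. It is illustrative only, not the MLIR implementation: -1 stands in for ShapedType::kDynamic, std::vector replaces ArrayRef/SmallVector, and the tail behavior of appending leftover source dims to the last group is inferred from the truncated end of the diff rather than quoted from it.

// Minimal standalone sketch of getReassociationIndicesForCollapse's new
// matching logic. Assumptions: -1 replaces ShapedType::kDynamic, and leftover
// source dims are appended to the last group (inferred, see note above).
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1;

std::optional<std::vector<std::vector<int64_t>>>
getReassociationForCollapse(const std::vector<int64_t> &src,
                            const std::vector<int64_t> &tgt) {
  if (src.size() <= tgt.size())
    return std::nullopt;
  std::vector<std::vector<int64_t>> map;
  size_t s = 0;
  for (size_t t = 0; t < tgt.size(); ++t) {
    std::vector<int64_t> group;
    if (tgt[t] == kDynamic) {
      // Dynamic target dim: greedily take source dims up to and including
      // the first dynamic one.
      bool foundDynamic = false;
      while (s < src.size()) {
        group.push_back(s);
        if (src[s++] == kDynamic) {
          foundDynamic = true;
          break;
        }
      }
      if (!foundDynamic)
        return std::nullopt;
      map.push_back(group);
      continue;
    }
    // Static target dim: accumulate static source dims until their product
    // matches it exactly; overshooting or hitting a dynamic dim fails.
    int64_t prod = 1;
    bool matched = false;
    while (s < src.size()) {
      if (src[s] == kDynamic)
        return std::nullopt;
      prod *= src[s];
      if (prod > tgt[t])
        break;
      group.push_back(s++);
      if (prod == tgt[t]) {
        matched = true;
        break;
      }
    }
    if (!matched)
      return std::nullopt;
    map.push_back(group);
  }
  // Leftover source dims are allowed only if the last target dim is dynamic
  // (or the target is empty); otherwise each must be 1 or dynamic.
  for (; s < src.size(); ++s) {
    if ((tgt.empty() || tgt.back() != kDynamic) && src[s] != kDynamic &&
        src[s] != 1)
      return std::nullopt;
    if (!map.empty())
      map.back().push_back(s); // inferred tail behavior, see note above
  }
  return map;
}

void demo(const std::vector<int64_t> &src, const std::vector<int64_t> &tgt) {
  auto result = getReassociationForCollapse(src, tgt);
  if (!result) {
    std::cout << "rejected\n";
    return;
  }
  for (const auto &group : *result) {
    std::cout << "[ ";
    for (int64_t idx : group)
      std::cout << idx << " ";
    std::cout << "] ";
  }
  std::cout << "\n";
}

int main() {
  demo({2, 3, 4}, {6, 4});          // [ 0 1 ] [ 2 ]
  demo({kDynamic, 4, kDynamic, 2},
       {kDynamic, 4, kDynamic});    // [ 0 ] [ 1 ] [ 2 3 ]
  demo({4, 2, 3}, {6, 4});          // rejected: 4 * 2 overshoots 6
}

The second call mirrors the new fold_expand_of_collapse_mixed_target_subshape test below: tensor<?x4x?x2xf32> collapsing into tensor<?x4x?xf32> groups as [[0], [1], [2, 3]].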
mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir (2 additions, 2 deletions)
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT:   tensor.collapse
-// CHECK:       linalg.unpack
+// CHECK:       tensor.collapse
+// CHECK-NOT:   linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
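The flipped CHECK expectations follow directly from the loosened utility: collapsing tensor<?x32xf32> into tensor<?xf32> is now recognized as a valid single-group reassociation, since a trailing static dimension can be absorbed by a final dynamic target dimension, so the unpack simplifies to a plain tensor.collapse_shape. In terms of the illustrative sketch above (a stand-in for, not the actual, MLIR utility):

demo({kDynamic, 32}, {kDynamic}); // prints "[ 0 1 ]": the 32 joins the dynamic dim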
mlir/test/Dialect/Tensor/canonicalize.mlir (20 additions, 4 deletions)
@@ -1068,28 +1068,44 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 //   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+//   CHECK-NOT:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+//  CHECK-SAME:     : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+//  CHECK-NEXT:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
 //       CHECK:   tensor.collapse_shape
 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
 //       CHECK:   return %[[EXPAND]]