Merged
Changes from 7 commits (24 total)
2ae4480  [mlir][tensor] Loosen restrictions on folding dynamic reshapes (AGindinson, Apr 28, 2025)
4f7c389  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 5, 2025)
18da6fe  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 9, 2025)
52ff4e0  [fixup] Algorithm rewrite (AGindinson, May 9, 2025)
1c85a68  [fixup] Add/expand unit tests (AGindinson, May 9, 2025)
0fe986e  [fixup] variable renaming (AGindinson, May 9, 2025)
e3aa239  [fixup] Additional edge-case (AGindinson, May 9, 2025)
16a932c  [WIP] Current tests pass (AGindinson, May 20, 2025)
dd36c47  [WIP] New tests (AGindinson, May 20, 2025)
114af4b  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 20, 2025)
07ed33d  [fixup] Add scalar target tests & fix em (AGindinson, May 20, 2025)
6e61a52  [fixup] for self-induced unit dims problem (AGindinson, May 21, 2025)
b0e5c93  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 22, 2025)
a6a18d6  [fixup] apply non-functional comments (AGindinson, May 23, 2025)
ce007de  [fixup] apply greedy logic suggestions (AGindinson, May 23, 2025)
66adf99  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 23, 2025)
15caa29  [fixup] improve `getNonOverlappingIndicesWith(&rhs)` (AGindinson, May 28, 2025)
6b5d5cd  Merge branch 'main' into artem/upstream/reassoc-expand-of-collapse (AGindinson, May 28, 2025)
35cb397  Merge branch 'main' into reassoc-expand-of-collapse (AGindinson, Jun 2, 2025)
20e9a9f  Merge branch 'main' into reassoc-expand-of-collapse (AGindinson, Jun 3, 2025)
880b394  [fixup] Reduce auto usage, drop obsolete variable (AGindinson, Jun 3, 2025)
cc6df04  [fixup] Move a comment to the right place (AGindinson, Jun 3, 2025)
ea9161d  [fixup] Clarify some early-return cases (AGindinson, Jun 3, 2025)
54abd87  [fixup] Improve auto usage further (AGindinson, Jun 3, 2025)
168 changes: 116 additions & 52 deletions mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,64 +31,128 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
-
-    int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
+
+  unsigned sourceDimIdx = 0, targetDimIdx = 0;
+  // Source dimensions iteration logic for static target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForStaticTargetDim =
+      [&](int64_t targetShape,
+          bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
+        return failure();
+      prodOfCollapsedDims *= sourceShape[sourceDimIdx];
+      resultIndices.push_back(sourceDimIdx);
+      if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
+        return failure();
+      while (prodOfCollapsedDims > targetShape) {
+        assert(!resultIndices.empty());
+        auto frontOffsetIdx = resultIndices.begin();
+        prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
+        resultIndices.erase(frontOffsetIdx);
+      }
+      if (prodOfCollapsedDims == targetShape) {
+        reachedTargetDimSize = true;
+        ++sourceDimIdx;
+        break;
+      }
+    }
+    if (!reachedTargetDimSize)
+      return failure();
+    return resultIndices;
+  };
+  // Source dimensions iteration logic for dynamic target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForDynamicTargetDim =
+      [&](bool allowStaticNonOnes,
+          bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
+    bool foundFirstDynamic = false;
+    while (sourceDimIdx < numSourceDims) {
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
+        if (foundFirstDynamic && !mapConsecutiveDynDims)
+          break;
+        foundFirstDynamic |= true;
+      } else {
+        if (foundFirstDynamic)
+          break;
+        else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
+          return failure();
+      }
+      resultIndices.push_back(sourceDimIdx++);
+    }
+    if (!foundFirstDynamic)
+      return failure();
+    return resultIndices;
+  };
+  // Iterate over target shape.
+  bool wasLastDimDynamic = false;
+  for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
+    int64_t currTargetShape = targetShape[targetDimIdx];
+    if (currTargetShape != ShapedType::kDynamic) {
+      unsigned sourceDimAtStart = sourceDimIdx;
+      auto indices = collectSourceIndicesForStaticTargetDim(
+          currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
+      if (failed(indices))
+        return std::nullopt;
+      if (wasLastDimDynamic) {
+        assert(!reassociationMap.empty());
+        auto &previousIndices = reassociationMap.back();
+        for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart)
+          previousIndices.push_back(sourceDimAtStart);
+      }
+      reassociationMap.push_back(*indices);
+      wasLastDimDynamic = false;
+      continue;
     }
 
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
-      return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+    bool isNextDimDynamic =
+        targetDimIdx + 1 < numTargetDims &&
+        targetShape[targetDimIdx + 1] == ShapedType::kDynamic;
+    auto indices = collectSourceIndicesForDynamicTargetDim(
+        /*allowStaticNonOnes=*/!wasLastDimDynamic,
+        /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
+    if (failed(indices))
       return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    reassociationMap.push_back(*indices);
+    wasLastDimDynamic = true;
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
-    return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
-        sourceShape[sourceDim] != 1)
-      return std::nullopt;
+  // Now that we've mapped all the target dimensions, process any remaining
+  // entries in the source shape explicitly.
+  for (; sourceDimIdx < numSourceDims; sourceDimIdx++) {
+    const bool isOne = sourceShape[sourceDimIdx] == 1,
+               isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic;
+    if (targetShape.empty()) {
+      if (!isOne && !isDynamic)
+        return std::nullopt;
+      continue;
+    }
+    // If the last 2 dimensions in the target were dynamic, the tail in the
+    // source shape cannot contain a dynamic value. E.g. ?x?->? is valid,
+    // however ?x?x10x?->?x? would be indeterminate.
+    if (wasLastDimDynamic && numTargetDims > 1 &&
+        targetShape[numTargetDims - 2] == ShapedType::kDynamic) {
+      if (isDynamic)
+        return std::nullopt;
+    }
+    // If the last target dimension is static, only source dimensions of 1 are
+    // acceptable.
+    if (!wasLastDimDynamic && !isOne)
+      return std::nullopt;
-    // The map is empty when the target type is a scalar.
-    if (!reassociationMap.empty())
-      reassociationMap.back().push_back(sourceDim);
+    assert(!reassociationMap.empty());
+    reassociationMap.back().push_back(sourceDimIdx);
   }
   return reassociationMap;
 }
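To make the rewritten greedy matching concrete, here is a minimal standalone sketch (not part of the PR; the main driver, includes, and printing are illustrative assumptions, only getReassociationIndicesForCollapse itself is the patched utility). It exercises the subtlest path: a static target dim whose candidate group initially over-shoots, so indices are dropped from the group's front and spill back into the preceding dynamic group.

// Usage sketch, assuming compilation against MLIR and linking with
// MLIRDialectUtils.
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  // Collapse tensor<?x3x2x5x2xf32> into tensor<?x20xf32>. The trailing static
  // 20 is formed from 2*5*2; the 3 that over-shoots the product is pushed
  // back into the leading dynamic group, yielding {{0, 1}, {2, 3, 4}}.
  std::optional<SmallVector<ReassociationIndices>> groups =
      getReassociationIndicesForCollapse(
          {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20});
  if (!groups) {
    llvm::outs() << "no unique reassociation\n";
    return 1;
  }
  for (const ReassociationIndices &group : *groups) {
    llvm::outs() << "[";
    llvm::interleave(group, llvm::outs(), " ");
    llvm::outs() << "] ";
  }
  llvm::outs() << "\n"; // prints: [0 1] [2 3 4]
}

The same shape pair appears in the DynamicTest unit test added at the bottom of this PR, so the expected grouping is taken directly from there.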
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
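The flipped CHECK lines capture the user-visible effect of the patch: collapsing tensor<?x32xf32> into tensor<?xf32> now has a computable reassociation, so the unpack simplifies to a plain tensor.collapse_shape instead of surviving as linalg.unpack. A hedged standalone sketch of the underlying shape query (the assert-based driver is an illustrative assumption, mirroring the unit tests added below):

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include <cassert>

using namespace mlir;

int main() {
  // {?, 32} -> {?}: the old tail handling could not append the trailing
  // static 32 to the dynamic group; the new logic folds both source dims
  // into the single dynamic target dim.
  auto groups = getReassociationIndicesForCollapse(
      {ShapedType::kDynamic, 32}, {ShapedType::kDynamic});
  assert(groups.has_value() && groups->size() == 1 &&
         groups->front() == ReassociationIndices({0, 1}));
  (void)groups;
}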
39 changes: 35 additions & 4 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,28 +1068,59 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32>
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 // CHECK-NOT: tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
 // CHECK: tensor.collapse_shape
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
 // CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
+      : tensor<?x?x?xf32> into tensor<?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
+      : tensor<?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
+// CHECK: tensor.collapse_shape
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
+// CHECK: return %[[EXPAND]]
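The adjacent-dynamic case stays unfolded for a principled reason: with source ?x?x? and target ?x?, both groupings {0}{1, 2} and {0, 1}{2} are consistent, so there is no unique reassociation to fold to. A sketch of that check (illustrative driver; the same expectation appears verbatim in the DynamicTestFailure unit test below):

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include <cassert>
#include <optional>

using namespace mlir;

int main() {
  // ?x?x? -> ?x? is ambiguous ({0}{1,2} vs. {0,1}{2}), so the utility must
  // refuse to guess and the expand(collapse) pair is left intact.
  assert(getReassociationIndicesForCollapse(
             {ShapedType::kDynamic, ShapedType::kDynamic,
              ShapedType::kDynamic},
             {ShapedType::kDynamic, ShapedType::kDynamic}) == std::nullopt);
}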
1 change: 1 addition & 0 deletions mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRDialectUtilsTests
   StructuredOpsUtilsTest.cpp
+  ReshapeOpsUtilsTest.cpp
   IndexingUtilsTest.cpp
 )
 mlir_target_link_libraries(MLIRDialectUtilsTests
134 changes: 134 additions & 0 deletions mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -0,0 +1,134 @@
//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "gtest/gtest.h"
#include <optional>

using namespace mlir;

/// Helper to make constructing
/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
static std::optional<SmallVector<ReassociationIndices>>
makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
  return std::optional<SmallVector<ReassociationIndices>>(list);
}

TEST(ReassociationIndicesForCollapse, StaticTest) {
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
            makeOptionalIndices({{0, 1}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
            makeOptionalIndices({{0}, {1, 2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
            makeOptionalIndices({{0, 1}, {2}}));
}

TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
            std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
            std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
            std::nullopt);
}

TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
  EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
            makeOptionalIndices({{0, 1}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
            makeOptionalIndices({{0, 1, 2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
            makeOptionalIndices({{0, 1, 2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1, 1}),
            makeOptionalIndices({{0}, {1, 2}}));
}

TEST(ReassociationIndicesForCollapse, DynamicTest) {
  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
                                               {ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
                                               {ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1, 2}}));
  EXPECT_EQ(
      getReassociationIndicesForCollapse(
          {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
      makeOptionalIndices({{0, 1}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {1, ShapedType::kDynamic, ShapedType::kDynamic},
                {1, ShapedType::kDynamic}),
            makeOptionalIndices({{0}, {1, 2}}));

  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
                                               {ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {1, ShapedType::kDynamic, ShapedType::kDynamic},
                {ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1, 2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
                                               {ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
                                               {ShapedType::kDynamic, 20}),
            makeOptionalIndices({{0, 1}, {2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
                                               {ShapedType::kDynamic, 20}),
            makeOptionalIndices({{0, 1}, {2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
                {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
                                               {ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1, 2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {ShapedType::kDynamic, ShapedType::kDynamic, 1},
                {ShapedType::kDynamic, ShapedType::kDynamic}),
            makeOptionalIndices({{0}, {1, 2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {1, ShapedType::kDynamic, ShapedType::kDynamic},
                {ShapedType::kDynamic, ShapedType::kDynamic}),
            makeOptionalIndices({{0, 1}, {2}}));
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {ShapedType::kDynamic, 1, ShapedType::kDynamic},
                {ShapedType::kDynamic, ShapedType::kDynamic}),
            makeOptionalIndices({{0}, {1, 2}}));
}

TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
                                               {ShapedType::kDynamic, 10}),
            std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {ShapedType::kDynamic, 10, ShapedType::kDynamic},
                {ShapedType::kDynamic, ShapedType::kDynamic}),
            std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
                {ShapedType::kDynamic, ShapedType::kDynamic}),
            std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
            std::nullopt);
  EXPECT_EQ(
      getReassociationIndicesForCollapse(
          {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
          {ShapedType::kDynamic, ShapedType::kDynamic}),
      std::nullopt);
  EXPECT_EQ(getReassociationIndicesForCollapse(
                {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
                 ShapedType::kDynamic},
                {ShapedType::kDynamic, ShapedType::kDynamic}),
            std::nullopt);
}