Skip to content

Commit e3aa239

Browse files
committed
[fixup] Additional edge-case
Signed-off-by: Artem Gindinson <[email protected]>
1 parent 0fe986e commit e3aa239

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,14 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
139139
return std::nullopt;
140140
continue;
141141
}
142-
if (wasLastDimDynamic && isDynamic)
143-
return std::nullopt;
142+
// If the last 2 dimensions in the target were dynamic, the tail in the
143+
// source shape cannot contain a dynamic value. E.g. ?x?->? is valid,
144+
// however ?x?x10x?->?x? would be indeterminate.
145+
if (wasLastDimDynamic && numTargetDims > 1 &&
146+
targetShape[numTargetDims - 2] == ShapedType::kDynamic) {
147+
if (isDynamic)
148+
return std::nullopt;
149+
}
144150
// If the last target dimension is static, only source dimensions of 1 are
145151
// acceptable.
146152
if (!wasLastDimDynamic && !isOne)

mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
9292
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
9393
{ShapedType::kDynamic}),
9494
makeOptionalIndices({{0, 1, 2}}));
95+
EXPECT_EQ(getReassociationIndicesForCollapse(
96+
{ShapedType::kDynamic, ShapedType::kDynamic, 1},
97+
{ShapedType::kDynamic, ShapedType::kDynamic}),
98+
makeOptionalIndices({{0}, {1, 2}}));
9599
EXPECT_EQ(getReassociationIndicesForCollapse(
96100
{1, ShapedType::kDynamic, ShapedType::kDynamic},
97101
{ShapedType::kDynamic, ShapedType::kDynamic}),
@@ -122,4 +126,9 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
122126
{ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
123127
{ShapedType::kDynamic, ShapedType::kDynamic}),
124128
std::nullopt);
129+
EXPECT_EQ(getReassociationIndicesForCollapse(
130+
{ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
131+
ShapedType::kDynamic},
132+
{ShapedType::kDynamic, ShapedType::kDynamic}),
133+
std::nullopt);
125134
}

0 commit comments

Comments
 (0)