Skip to content

Commit c1f4264

Browse files
committed
fix edge case where n=k (rank-preserving shape_cast)
1 parent 9a1ece2 commit c1f4264

File tree

4 files changed

+61
-42
lines changed

4 files changed

+61
-42
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,18 +2244,19 @@ def Vector_ShapeCastOp :
22442244
Results<(outs AnyVectorOfAnyRank:$result)> {
22452245
let summary = "shape_cast casts between vector shapes";
22462246
let description = [{
2247-
The shape_cast operation casts between an n-D source vector shape and
2248-
a k-D result vector shape (the element type remains the same).
2247+
The shape_cast operation casts from an n-D source vector to a k-D result
2248+
vector. The element type remains the same, as does the number of elements
2249+
(product of dimensions).
2250+
2251+
If reducing or preserving rank (n >= k), all result dimension sizes must be
2252+
products of contiguous source dimension sizes. If expanding rank (n < k),
2253+
source dimensions must all factor into contiguous sequences of destination
2254+
dimension sizes.
22492255

2250-
If reducing rank (n > k), result dimension sizes must be a product
2251-
of contiguous source dimension sizes.
2252-
If expanding rank (n < k), source dimensions must factor into a
2253-
contiguous sequence of destination dimension sizes.
22542256
Each source dim is expanded (or contiguous sequence of source dims combined)
22552257
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
22562258
sequence of result dims (or a single result dim), in result dimension list
2257-
order (i.e. 0 <= j < k). The product of all source dimension sizes and all
2258-
result dimension sizes must match.
2259+
order (i.e. 0 <= j < k).
22592260

22602261
It is currently assumed that this operation does not require moving data,
22612262
and that it will be folded away before lowering vector operations.

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5534,10 +5534,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55345534

55355535
/// Returns true if each element of 'a' is equal to the product of a contiguous
55365536
/// sequence of the elements of 'b'. Returns false otherwise.
5537-
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5537+
static bool isValidExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
55385538
unsigned rankA = a.size();
55395539
unsigned rankB = b.size();
5540-
assert(rankA < rankB);
5540+
assert(rankA <= rankB);
55415541

55425542
auto isOne = [](int64_t v) { return v == 1; };
55435543

@@ -5573,34 +5573,36 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
55735573
VectorType resultVectorType) {
55745574
// Check that element type is the same.
55755575
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5576-
return op->emitOpError("source/result vectors must have same element type");
5577-
auto sourceShape = sourceVectorType.getShape();
5578-
auto resultShape = resultVectorType.getShape();
5576+
return op->emitOpError("has different source and result element types");
5577+
ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
5578+
ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
5579+
if (lowRankShape.size() > highRankShape.size())
5580+
std::swap(lowRankShape, highRankShape);
55795581

55805582
// Check that product of source dim sizes matches product of result dim sizes.
5581-
int64_t sourceDimProduct = std::accumulate(
5582-
sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5583-
int64_t resultDimProduct = std::accumulate(
5584-
resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5585-
if (sourceDimProduct != resultDimProduct)
5586-
return op->emitOpError("source/result number of elements must match");
5587-
5588-
// Check that expanding/contracting rank cases.
5589-
unsigned sourceRank = sourceVectorType.getRank();
5590-
unsigned resultRank = resultVectorType.getRank();
5591-
if (sourceRank < resultRank) {
5592-
if (!isValidShapeCast(sourceShape, resultShape))
5593-
return op->emitOpError("invalid shape cast");
5594-
} else if (sourceRank > resultRank) {
5595-
if (!isValidShapeCast(resultShape, sourceShape))
5596-
return op->emitOpError("invalid shape cast");
5583+
int64_t nLowRankElms =
5584+
std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
5585+
std::multiplies<int64_t>{});
5586+
int64_t nHighRankElms =
5587+
std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
5588+
std::multiplies<int64_t>{});
5589+
5590+
if (nLowRankElms != nHighRankElms) {
5591+
return op->emitOpError(
5592+
"has a different number of source and result elements");
5593+
}
5594+
5595+
if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
5596+
return op->emitOpError(
5597+
"is invalid (does not uniformly collapse or expand)");
55975598
}
55985599

55995600
// Check that (non-)scalability is preserved
56005601
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
56015602
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
56025603
if (sourceNScalableDims != resultNScalableDims)
5603-
return op->emitOpError("different number of scalable dims at source (")
5604+
return op->emitOpError(
5605+
"has a different number of scalable dims at source (")
56045606
<< sourceNScalableDims << ") and result (" << resultNScalableDims
56055607
<< ")";
56065608
sourceVectorType.getNumDynamicDims();
@@ -5634,17 +5636,18 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56345636

56355637
// Only allows valid transitive folding (expand/collapse dimensions).
56365638
VectorType srcType = otherOp.getSource().getType();
5639+
56375640
if (resultType == srcType)
56385641
return otherOp.getSource();
5639-
if (srcType.getRank() < resultType.getRank()) {
5640-
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5641-
return {};
5642-
} else if (srcType.getRank() > resultType.getRank()) {
5643-
if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5644-
return {};
5645-
} else {
5642+
5643+
ArrayRef<int64_t> lowRankShape = srcType.getShape();
5644+
ArrayRef<int64_t> highRankShape = resultType.getShape();
5645+
if (lowRankShape.size() > highRankShape.size())
5646+
std::swap(lowRankShape, highRankShape);
5647+
5648+
if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
56465649
return {};
5647-
}
5650+
56485651
setOperand(otherOp.getSource());
56495652
return getResult();
56505653
}

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
11321132
// -----
11331133

11341134
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
1135-
// expected-error@+1 {{op source/result vectors must have same element type}}
1135+
// expected-error@+1 {{op has different source and result element types}}
11361136
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
11371137
}
11381138

11391139
// -----
11401140

11411141
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
1142-
// expected-error@+1 {{op source/result number of elements must match}}
1142+
// expected-error@+1 {{op has a different number of source and result elements}}
11431143
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
11441144
}
11451145

11461146
// -----
11471147

1148+
func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
1149+
// expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
1150+
%0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
1151+
}
1152+
1153+
// -----
1154+
11481155
func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
1149-
// expected-error@+1 {{invalid shape cast}}
1156+
// expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
11501157
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
11511158
}
11521159

11531160
// -----
11541161

11551162
func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
1156-
// expected-error@+1 {{invalid shape cast}}
1163+
// expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
11571164
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
11581165
}
11591166

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,14 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
576576
return %1 : vector<1x1x1x1xf32>
577577
}
578578

579+
// CHECK-LABEL: @shape_cast_rank_preserving
580+
func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
581+
582+
// CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
583+
%0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
584+
return %0 : vector<4x1xf32>
585+
}
586+
579587
// CHECK-LABEL: @bitcast
580588
func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
581589
%arg1 : vector<8x1xi32>,

0 commit comments

Comments
 (0)