Skip to content

Commit 84dff32

Browse files
committed
remove checks for collapse / expand
1 parent f0a59c4 commit 84dff32

File tree

5 files changed

+32
-108
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,18 +2244,8 @@ def Vector_ShapeCastOp :
22442244
Results<(outs AnyVectorOfAnyRank:$result)> {
22452245
let summary = "shape_cast casts between vector shapes";
22462246
let description = [{
2247-
The shape_cast operation casts between an n-D source vector shape and
2248-
a k-D result vector shape (the element type remains the same).
2249-
2250-
If reducing rank (n > k), result dimension sizes must be a product
2251-
of contiguous source dimension sizes.
2252-
If expanding rank (n < k), source dimensions must factor into a
2253-
contiguous sequence of destination dimension sizes.
2254-
Each source dim is expanded (or contiguous sequence of source dims combined)
2255-
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
2256-
sequence of result dims (or a single result dim), in result dimension list
2257-
order (i.e. 0 <= j < k). The product of all source dimension sizes and all
2258-
result dimension sizes must match.
2247+
The shape_cast operation casts from a source vector to a target vector,
2248+
retaining the element type and number of elements.
22592249

22602250
It is currently assumed that this operation does not require moving data,
22612251
and that it will be folded away before lowering vector operations.
@@ -2268,12 +2258,7 @@ def Vector_ShapeCastOp :
22682258
Example:
22692259

22702260
```mlir
2271-
// Example casting to a lower vector rank.
2272-
%1 = vector.shape_cast %0 : vector<5x1x4x3xf32> to vector<20x3xf32>
2273-
2274-
// Example casting to a higher vector rank.
2275-
%3 = vector.shape_cast %2 : vector<10x12x8xf32> to vector<5x2x3x4x8xf32>
2276-
2261+
%1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
22772262
```
22782263
}];
22792264
let extraClassDeclaration = [{

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5505,48 +5505,18 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55055505
setResultRanges(getResult(), argRanges.front());
55065506
}
55075507

5508-
/// Returns true if each element of 'a' is equal to the product of a contiguous
5509-
/// sequence of the elements of 'b'. Returns false otherwise.
5510-
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5511-
unsigned rankA = a.size();
5512-
unsigned rankB = b.size();
5513-
assert(rankA < rankB);
5514-
5515-
auto isOne = [](int64_t v) { return v == 1; };
5516-
5517-
// Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5518-
// casted to a 0-d vector.
5519-
if (rankA == 0 && llvm::all_of(b, isOne))
5520-
return true;
5521-
5522-
unsigned i = 0;
5523-
unsigned j = 0;
5524-
while (i < rankA && j < rankB) {
5525-
int64_t dimA = a[i];
5526-
int64_t dimB = 1;
5527-
while (dimB < dimA && j < rankB)
5528-
dimB *= b[j++];
5529-
if (dimA != dimB)
5530-
break;
5531-
++i;
5532-
5533-
// Handle the case when trailing dimensions are of size 1.
5534-
// Include them into the contiguous sequence.
5535-
if (i < rankA && llvm::all_of(a.slice(i), isOne))
5536-
i = rankA;
5537-
if (j < rankB && llvm::all_of(b.slice(j), isOne))
5538-
j = rankB;
5539-
}
5508+
LogicalResult ShapeCastOp::verify() {
5509+
auto sourceVectorType =
5510+
llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5511+
auto resultVectorType =
5512+
llvm::dyn_cast_or_null<VectorType>(getResult().getType());
55405513

5541-
return i == rankA && j == rankB;
5542-
}
5514+
if (!sourceVectorType) return failure();
5515+
if (!resultVectorType) return failure();
55435516

5544-
static LogicalResult verifyVectorShapeCast(Operation *op,
5545-
VectorType sourceVectorType,
5546-
VectorType resultVectorType) {
55475517
// Check that element type is the same.
55485518
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5549-
return op->emitOpError("source/result vectors must have same element type");
5519+
return emitOpError("source/result vectors must have same element type");
55505520
auto sourceShape = sourceVectorType.getShape();
55515521
auto resultShape = resultVectorType.getShape();
55525522

@@ -5556,44 +5526,20 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
55565526
int64_t resultDimProduct = std::accumulate(
55575527
resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
55585528
if (sourceDimProduct != resultDimProduct)
5559-
return op->emitOpError("source/result number of elements must match");
5560-
5561-
// Check that expanding/contracting rank cases.
5562-
unsigned sourceRank = sourceVectorType.getRank();
5563-
unsigned resultRank = resultVectorType.getRank();
5564-
if (sourceRank < resultRank) {
5565-
if (!isValidShapeCast(sourceShape, resultShape))
5566-
return op->emitOpError("invalid shape cast");
5567-
} else if (sourceRank > resultRank) {
5568-
if (!isValidShapeCast(resultShape, sourceShape))
5569-
return op->emitOpError("invalid shape cast");
5570-
}
5529+
return emitOpError("source/result number of elements must match");
55715530

55725531
// Check that (non-)scalability is preserved
55735532
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
55745533
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
55755534
if (sourceNScalableDims != resultNScalableDims)
5576-
return op->emitOpError("different number of scalable dims at source (")
5535+
return emitOpError("different number of scalable dims at source (")
55775536
<< sourceNScalableDims << ") and result (" << resultNScalableDims
55785537
<< ")";
55795538
sourceVectorType.getNumDynamicDims();
55805539

55815540
return success();
55825541
}
55835542

5584-
LogicalResult ShapeCastOp::verify() {
5585-
auto sourceVectorType =
5586-
llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5587-
auto resultVectorType =
5588-
llvm::dyn_cast_or_null<VectorType>(getResult().getType());
5589-
5590-
// Check if source/result are of vector type.
5591-
if (sourceVectorType && resultVectorType)
5592-
return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
5593-
5594-
return success();
5595-
}
5596-
55975543
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
55985544

55995545
// No-op shape cast.
@@ -5609,15 +5555,6 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56095555
VectorType srcType = otherOp.getSource().getType();
56105556
if (resultType == srcType)
56115557
return otherOp.getSource();
5612-
if (srcType.getRank() < resultType.getRank()) {
5613-
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
5614-
return {};
5615-
} else if (srcType.getRank() > resultType.getRank()) {
5616-
if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
5617-
return {};
5618-
} else {
5619-
return {};
5620-
}
56215558
setOperand(otherOp.getSource());
56225559
return getResult();
56235560
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -950,10 +950,9 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
950950

951951
// -----
952952

953-
// CHECK-LABEL: dont_fold_expand_collapse
954-
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
955-
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
956-
// CHECK: return %[[B]] : vector<8x8xf32>
953+
// CHECK-LABEL: fold_expand_collapse
954+
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<8x8xf32>
955+
// CHECK: return %[[A]] : vector<8x8xf32>
957956
func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
958957
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
959958
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,19 +1145,6 @@ func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
11451145

11461146
// -----
11471147

1148-
func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
1149-
// expected-error@+1 {{invalid shape cast}}
1150-
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
1151-
}
1152-
1153-
// -----
1154-
1155-
func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
1156-
// expected-error@+1 {{invalid shape cast}}
1157-
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
1158-
}
1159-
1160-
// -----
11611148

11621149
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
11631150
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,22 @@ func.func @vector_print_on_scalar(%arg0: i64) {
543543
return
544544
}
545545

546+
// CHECK-LABEL: @shape_cast_valid_rank_reduction
547+
func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
548+
// CHECK: vector.shape_cast %{{.*}} : vector<5x1x3x2xf32> to vector<2x15xf32>
549+
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
550+
return
551+
}
552+
553+
554+
// CHECK-LABEL: @shape_cast_valid_rank_expansion
555+
func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
556+
// CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
557+
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
558+
return
559+
}
560+
561+
546562
// CHECK-LABEL: @shape_cast
547563
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
548564
%arg1 : vector<8x1xf32>,

0 commit comments

Comments (0)