Skip to content

Commit 7bfc219

Browse files
committed
update tests
1 parent b491463 commit 7bfc219

File tree

4 files changed

+52
-78
lines changed

4 files changed

+52
-78
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5534,11 +5534,12 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55345534

55355535
/// Returns true if each element of 'a' is equal to the product of a contiguous
55365536
/// sequence of the elements of 'b'. Returns false otherwise.
5537-
static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
5538-
ArrayRef<int64_t> b) {
5537+
static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
55395538
unsigned rankA = a.size();
55405539
unsigned rankB = b.size();
5541-
assert(rankA <= rankB);
5540+
if (rankA > rankB) {
5541+
return false;
5542+
}
55425543

55435544
auto isOne = [](int64_t v) { return v == 1; };
55445545

@@ -5565,35 +5566,34 @@ static bool isValidExpandingShapeCast(ArrayRef<int64_t> a,
55655566
if (j < rankB && llvm::all_of(b.slice(j), isOne))
55665567
j = rankB;
55675568
}
5568-
55695569
return i == rankA && j == rankB;
55705570
}
55715571

5572+
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5573+
return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
5574+
}
5575+
55725576
static LogicalResult verifyVectorShapeCast(Operation *op,
55735577
VectorType sourceVectorType,
55745578
VectorType resultVectorType) {
55755579
// Check that element type is the same.
55765580
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
55775581
return op->emitOpError("has different source and result element types");
5578-
ArrayRef<int64_t> lowRankShape = sourceVectorType.getShape();
5579-
ArrayRef<int64_t> highRankShape = resultVectorType.getShape();
5580-
if (lowRankShape.size() > highRankShape.size())
5581-
std::swap(lowRankShape, highRankShape);
5582+
ArrayRef<int64_t> inShape = sourceVectorType.getShape();
5583+
ArrayRef<int64_t> outShape = resultVectorType.getShape();
55825584

55835585
// Check that product of source dim sizes matches product of result dim sizes.
5584-
int64_t nLowRankElms =
5585-
std::accumulate(lowRankShape.begin(), lowRankShape.end(), 1LL,
5586-
std::multiplies<int64_t>{});
5587-
int64_t nHighRankElms =
5588-
std::accumulate(highRankShape.begin(), highRankShape.end(), 1LL,
5589-
std::multiplies<int64_t>{});
5590-
5591-
if (nLowRankElms != nHighRankElms) {
5586+
int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
5587+
std::multiplies<int64_t>{});
5588+
int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
5589+
std::multiplies<int64_t>{});
5590+
5591+
if (nInElms != nOutElms) {
55925592
return op->emitOpError(
55935593
"has a different number of source and result elements");
55945594
}
55955595

5596-
if (!isValidExpandingShapeCast(lowRankShape, highRankShape)) {
5596+
if (!isValidShapeCast(inShape, outShape)) {
55975597
return op->emitOpError(
55985598
"is invalid (does not uniformly collapse or expand)");
55995599
}
@@ -5641,12 +5641,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
56415641
if (resultType == srcType)
56425642
return otherOp.getSource();
56435643

5644-
ArrayRef<int64_t> lowRankShape = srcType.getShape();
5645-
ArrayRef<int64_t> highRankShape = resultType.getShape();
5646-
if (lowRankShape.size() > highRankShape.size())
5647-
std::swap(lowRankShape, highRankShape);
5648-
5649-
if (!isValidExpandingShapeCast(lowRankShape, highRankShape))
5644+
if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
56505645
return {};
56515646

56525647
setOperand(otherOp.getSource());

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,12 +1290,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
12901290
// -----
12911291

12921292
// CHECK-LABEL: consecutive_shape_cast
1293-
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
1294-
// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
1295-
func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
1293+
// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
1294+
// CHECK-NEXT: return %[[C]] : vector<2x2x4xf16>
1295+
func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
12961296
%0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
1297-
%1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
1298-
return %1 : vector<4x4xf16>
1297+
%1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
1298+
return %1 : vector<2x2x4xf16>
12991299
}
13001300

13011301
// -----

mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,23 @@ func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]
7474

7575
// CHECK-LABEL: f32_permute_leading_non_scalable_dims
7676
// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
77-
func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
78-
// CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[4]xf32>
77+
func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<1x6x[4]xf32> {
78+
// CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<1x6x[4]xf32>
7979
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
80-
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
80+
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[4]xf32> into vector<1x6x[4]xf32>
8181
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
82-
// CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
82+
// CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<1x6x[4]xf32>
8383
// CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
84-
// CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
84+
// CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [0, 2] : vector<[4]xf32> into vector<1x6x[4]xf32>
8585
// CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[4]xf32> from vector<2x3x[4]xf32>
86-
// CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
86+
// CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [0, 3] : vector<[4]xf32> into vector<1x6x[4]xf32>
8787
// CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[4]xf32> from vector<2x3x[4]xf32>
88-
// CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
88+
// CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [0, 4] : vector<[4]xf32> into vector<1x6x[4]xf32>
8989
// CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<[4]xf32> from vector<2x3x[4]xf32>
90-
// CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
91-
%res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
92-
// CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
93-
return %res : vector<3x2x[4]xf32>
90+
// CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [0, 5] : vector<[4]xf32> into vector<1x6x[4]xf32>
91+
%res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<1x6x[4]xf32>
92+
// CHECK-NEXT: return %[[res5]] : vector<1x6x[4]xf32>
93+
return %res : vector<1x6x[4]xf32>
9494
}
9595

9696
// -----
@@ -117,48 +117,48 @@ func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) ->
117117

118118
// CHECK-LABEL: f32_reduce_trailing_scalable_dim
119119
// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
120-
func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
120+
func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<3x2x[2]xf32>
121121
{
122-
// CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<6x[2]xf32>
122+
// CHECK-NEXT: %[[ub:.*]] = ub.poison : vector<3x2x[2]xf32>
123123
// CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<[4]xf32> from vector<3x[4]xf32>
124124
// CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
125-
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
125+
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[ub]] [0, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
126126
// CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
127-
// CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
127+
// CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
128128
// CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<[4]xf32> from vector<3x[4]xf32>
129129
// CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
130-
// CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
130+
// CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
131131
// CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
132-
// CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
132+
// CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
133133
// CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<[4]xf32> from vector<3x[4]xf32>
134134
// CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
135-
// CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
135+
// CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[2]xf32> into vector<3x2x[2]xf32>
136136
// CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
137-
// CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
138-
%res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
139-
// CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
140-
return %res: vector<6x[2]xf32>
137+
// CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[2]xf32> into vector<3x2x[2]xf32>
138+
%res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<3x2x[2]xf32>
139+
// CHECK-NEXT: return %[[res5]] : vector<3x2x[2]xf32>
140+
return %res: vector<3x2x[2]xf32>
141141
}
142142

143143
// -----
144144

145145
// CHECK-LABEL: f32_increase_trailing_scalable_dim
146-
// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
147-
func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
146+
// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf32>
147+
func.func @f32_increase_trailing_scalable_dim(%arg0: vector<2x2x[2]xf32>) -> vector<2x[4]xf32>
148148
{
149149
// CHECK-DAG: %[[ub0:.*]] = ub.poison : vector<2x[4]xf32>
150150
// CHECK-DAG: %[[ub1:.*]] = ub.poison : vector<[4]xf32>
151-
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<[2]xf32> from vector<4x[2]xf32>
151+
// CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<[2]xf32> from vector<2x2x[2]xf32>
152152
// CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32>
153-
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<[2]xf32> from vector<4x[2]xf32>
153+
// CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<[2]xf32> from vector<2x2x[2]xf32>
154154
// CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32>
155155
// CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[ub0]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
156-
// CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][2] : vector<[2]xf32> from vector<4x[2]xf32>
156+
// CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<[2]xf32> from vector<2x2x[2]xf32>
157157
// CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec3]], %[[ub1]][0] : vector<[2]xf32> into vector<[4]xf32>
158-
// CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][3] : vector<[2]xf32> from vector<4x[2]xf32>
158+
// CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<[2]xf32> from vector<2x2x[2]xf32>
159159
// CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec4]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32>
160160
// CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
161-
%res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32>
161+
%res = vector.shape_cast %arg0: vector<2x2x[2]xf32> to vector<2x[4]xf32>
162162
// CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32>
163163
return %res: vector<2x[4]xf32>
164164
}

mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,6 @@ func.func @shape_casts(%a: vector<2x2xf32>) -> (vector<4xf32>, vector<2x2xf32>)
5757
return %r0, %1 : vector<4xf32>, vector<2x2xf32>
5858
}
5959

60-
// CHECK-LABEL: func @shape_cast_2d2d
61-
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>
62-
// CHECK: %[[UB:.*]] = ub.poison : vector<2x3xf32>
63-
// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : f32 from vector<3x2xf32>
64-
// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[UB]] [0, 0] : f32 into vector<2x3xf32>
65-
// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : f32 from vector<3x2xf32>
66-
// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 1] : f32 into vector<2x3xf32>
67-
// CHECK: %[[T4:.*]] = vector.extract %[[A]][1, 0] : f32 from vector<3x2xf32>
68-
// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 2] : f32 into vector<2x3xf32>
69-
// CHECK: %[[T6:.*]] = vector.extract %[[A]][1, 1] : f32 from vector<3x2xf32>
70-
// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0] : f32 into vector<2x3xf32>
71-
// CHECK: %[[T8:.*]] = vector.extract %[[A]][2, 0] : f32 from vector<3x2xf32>
72-
// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 1] : f32 into vector<2x3xf32>
73-
// CHECK: %[[T10:.*]] = vector.extract %[[A]][2, 1] : f32 from vector<3x2xf32>
74-
// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 2] : f32 into vector<2x3xf32>
75-
// CHECK: return %[[T11]] : vector<2x3xf32>
76-
77-
func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
78-
%s = vector.shape_cast %arg0: vector<3x2xf32> to vector<2x3xf32>
79-
return %s : vector<2x3xf32>
80-
}
8160

8261
// CHECK-LABEL: func @shape_cast_3d1d
8362
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>

0 commit comments

Comments
 (0)