Skip to content

Commit 570b023

Browse files
committed
improve draft version, add canonicalizer
1 parent 5753fbc commit 570b023

File tree

5 files changed

+53
-41
lines changed

5 files changed

+53
-41
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2244,8 +2244,8 @@ def Vector_ShapeCastOp :
22442244
Results<(outs AnyVectorOfAnyRank:$result)> {
22452245
let summary = "shape_cast casts between vector shapes";
22462246
let description = [{
2247-
The shape_cast operation casts from a source vector to a target vector,
2248-
retaining the element type and number of elements.
2247+
Casts to a vector with the same number of elements, element type, and
2248+
number of scalable dimensions.
22492249

22502250
It is currently assumed that this operation does not require moving data,
22512251
and that it will be folded away before lowering vector operations.
@@ -2255,10 +2255,11 @@ def Vector_ShapeCastOp :
22552255
2-D MLIR vector to a 1-D flattened LLVM vector.shape_cast lowering to LLVM
22562256
is supported in that particular case, for now.
22572257

2258-
Example:
2258+
Examples:
22592259

22602260
```mlir
22612261
%1 = vector.shape_cast %0 : vector<4x3xf32> to vector<3x2x2xf32>
2262+
%2 = vector.shape_cast %0 : vector<[2]x3x[4]xi8> to vector<3x[1]x[8]xi8>
22622263
```
22632264
}];
22642265
let extraClassDeclaration = [{

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5506,65 +5506,66 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55065506
}
55075507

55085508
LogicalResult ShapeCastOp::verify() {
5509-
auto sourceVectorType =
5510-
llvm::dyn_cast_or_null<VectorType>(getSource().getType());
5511-
auto resultVectorType =
5512-
llvm::dyn_cast_or_null<VectorType>(getResult().getType());
55135509

5514-
if (!sourceVectorType)
5515-
return failure();
5516-
if (!resultVectorType)
5517-
return failure();
5510+
VectorType sourceType = getSourceVectorType();
5511+
VectorType resultType = getResultVectorType();
55185512

5519-
// Check that element type is the same.
5520-
if (sourceVectorType.getElementType() != resultVectorType.getElementType())
5521-
return emitOpError("source/result vectors must have same element type");
5522-
auto sourceShape = sourceVectorType.getShape();
5523-
auto resultShape = resultVectorType.getShape();
5513+
// Check that element type is preserved
5514+
if (sourceType.getElementType() != resultType.getElementType())
5515+
return emitOpError("has different source and result element types");
55245516

5525-
// Check that product of source dim sizes matches product of result dim sizes.
5526-
int64_t sourceDimProduct = std::accumulate(
5527-
sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
5528-
int64_t resultDimProduct = std::accumulate(
5529-
resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
5530-
if (sourceDimProduct != resultDimProduct)
5531-
return emitOpError("source/result number of elements must match");
5517+
// Check that number of elements is preserved
5518+
int64_t sourceNElms = sourceType.getNumElements();
5519+
int64_t resultNElms = resultType.getNumElements();
5520+
if (sourceNElms != resultNElms) {
5521+
return emitOpError() << "has different number of elements at source ("
5522+
<< sourceNElms << ") and result (" << resultNElms
5523+
<< ")";
5524+
}
55325525

55335526
// Check that (non-)scalability is preserved
5534-
int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
5535-
int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
5527+
int64_t sourceNScalableDims = sourceType.getNumScalableDims();
5528+
int64_t resultNScalableDims = resultType.getNumScalableDims();
55365529
if (sourceNScalableDims != resultNScalableDims)
5537-
return emitOpError("different number of scalable dims at source (")
5538-
<< sourceNScalableDims << ") and result (" << resultNScalableDims
5539-
<< ")";
5540-
sourceVectorType.getNumDynamicDims();
5530+
return emitOpError() << "has different number of scalable dims at source ("
5531+
<< sourceNScalableDims << ") and result ("
5532+
<< resultNScalableDims << ")";
55415533

55425534
return success();
55435535
}
55445536

55455537
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
55465538

5539+
VectorType resultType = getType();
5540+
55475541
// No-op shape cast.
5548-
if (getSource().getType() == getType())
5542+
if (getSource().getType() == resultType)
55495543
return getSource();
55505544

5551-
VectorType resultType = getType();
5552-
5553-
// Canceling shape casts.
5545+
// Y = shape_cast(shape_cast(X)))
5546+
// -> X, if X and Y have same type
5547+
// -> shape_cast(X) otherwise.
55545548
if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
5555-
5556-
// Only allows valid transitive folding (expand/collapse dimensions).
55575549
VectorType srcType = otherOp.getSource().getType();
55585550
if (resultType == srcType)
55595551
return otherOp.getSource();
55605552
setOperand(otherOp.getSource());
55615553
return getResult();
55625554
}
55635555

5564-
// Cancelling broadcast and shape cast ops.
5556+
// Y = shape_cast(broadcast(X))
5557+
// -> X, if X and Y have same type, else
5558+
// -> shape_cast(X) if X is a vector and the broadcast preserves
5559+
// number of elements.
55655560
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
55665561
if (bcastOp.getSourceType() == resultType)
55675562
return bcastOp.getSource();
5563+
if (auto bcastSrcType = dyn_cast<VectorType>(bcastOp.getSourceType())) {
5564+
if (bcastSrcType.getNumElements() == resultType.getNumElements()) {
5565+
setOperand(bcastOp.getSource());
5566+
return getResult();
5567+
}
5568+
}
55685569
}
55695570

55705571
// shape_cast(constant) -> constant

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,18 @@ func.func @fold_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<4xf32> {
972972

973973
// -----
974974

975+
// CHECK-LABEL: func @fold_count_preserving_broadcast_shapecast
976+
// CHECK-SAME: (%[[V:.+]]: vector<4xf32>)
977+
// CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[V]] : vector<4xf32> to vector<2x2xf32>
978+
// CHECK: return %[[SHAPECAST]] : vector<2x2xf32>
979+
func.func @fold_count_preserving_broadcast_shapecast(%arg0: vector<4xf32>) -> vector<2x2xf32> {
980+
%0 = vector.broadcast %arg0 : vector<4xf32> to vector<1x1x4xf32>
981+
%1 = vector.shape_cast %0 : vector<1x1x4xf32> to vector<2x2xf32>
982+
return %1 : vector<2x2xf32>
983+
}
984+
985+
// -----
986+
975987
// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
976988
// CHECK: vector.broadcast
977989
// CHECK-NOT: vector.shape_cast

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,21 +1131,21 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
11311131

11321132
// -----
11331133

1134+
11341135
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
1135-
// expected-error@+1 {{op source/result vectors must have same element type}}
1136+
// expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
11361137
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
11371138
}
11381139

11391140
// -----
11401141

11411142
func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
1142-
// expected-error@+1 {{op source/result number of elements must match}}
1143+
// expected-error@+1 {{'vector.shape_cast' op has different number of elements at source (30) and result (20)}}
11431144
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
11441145
}
11451146

11461147
// -----
11471148

1148-
11491149
func.func @shape_cast_scalability_flag_is_dropped(%arg0 : vector<15x[2]xf32>) {
11501150
// expected-error@+1 {{different number of scalable dims at source (1) and result (0)}}
11511151
%0 = vector.shape_cast %arg0 : vector<15x[2]xf32> to vector<30xf32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -550,15 +550,13 @@ func.func @shape_cast_valid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
550550
return
551551
}
552552

553-
554553
// CHECK-LABEL: @shape_cast_valid_rank_expansion
555554
func.func @shape_cast_valid_rank_expansion(%arg0 : vector<15x2xf32>) {
556555
// CHECK: vector.shape_cast %{{.*}} : vector<15x2xf32> to vector<5x2x3x1xf32>
557556
%0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
558557
return
559558
}
560559

561-
562560
// CHECK-LABEL: @shape_cast
563561
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
564562
%arg1 : vector<8x1xf32>,

0 commit comments

Comments
 (0)