Skip to content

Commit db1b717

Browse files
committed
tighten
1 parent 7bfc219 commit db1b717

File tree

4 files changed

+81
-59
lines changed

4 files changed

+81
-59
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2248,10 +2248,11 @@ def Vector_ShapeCastOp :
22482248
vector. The element type remains the same, as does the number of elements
22492249
(product of dimensions).
22502250

2251-
If reducing or preserving rank (n >= k), all result dimension sizes must be
2252-
products of contiguous source dimension sizes. If expanding rank (n < k),
2253-
source dimensions must all factor into contiguous sequences of destination
2254-
dimension sizes.
2251+
A shape_cast must be either collapsing or expanding. Collapsing means all
2252+
result dimension sizes are products of contiguous source dimension sizes.
2253+
Expanding means source dimensions all factor into contiguous sequences of
2254+
destination dimension sizes. Size 1 dimensions in source and destination
2255+
are ignored.
22552256

22562257
Each source dim is expanded (or contiguous sequence of source dims combined)
22572258
in source dimension list order (i.e. 0 <= i < n), to produce a contiguous

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5532,41 +5532,34 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
55325532
setResultRanges(getResult(), argRanges.front());
55335533
}
55345534

5535-
/// Returns true if each element of 'a' is equal to the product of a contiguous
5536-
/// sequence of the elements of 'b'. Returns false otherwise.
5535+
/// Returns true if each element of 'a' is either 1 or equal to the product of a
5536+
/// contiguous sequence of the elements of 'b'. Returns false otherwise.
5537+
///
5538+
/// This function assumes that the product of elements in a and b are the same.
55375539
static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
5538-
unsigned rankA = a.size();
5539-
unsigned rankB = b.size();
5540-
if (rankA > rankB) {
5541-
return false;
5542-
}
5543-
5544-
auto isOne = [](int64_t v) { return v == 1; };
5545-
5546-
// Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
5547-
// casted to a 0-d vector.
5548-
if (rankA == 0 && llvm::all_of(b, isOne))
5549-
return true;
55505540

5541+
unsigned rankA = a.size();
55515542
unsigned i = 0;
55525543
unsigned j = 0;
5553-
while (i < rankA && j < rankB) {
5544+
while (i < rankA) {
5545+
if (a[i] == 1) {
5546+
++i;
5547+
continue;
5548+
}
5549+
55545550
int64_t dimA = a[i];
55555551
int64_t dimB = 1;
5556-
while (dimB < dimA && j < rankB)
5552+
5553+
while (dimB < dimA) {
55575554
dimB *= b[j++];
5558-
if (dimA != dimB)
5559-
break;
5560-
++i;
5555+
}
55615556

5562-
// Handle the case when trailing dimensions are of size 1.
5563-
// Include them into the contiguous sequence.
5564-
if (i < rankA && llvm::all_of(a.slice(i), isOne))
5565-
i = rankA;
5566-
if (j < rankB && llvm::all_of(b.slice(j), isOne))
5567-
j = rankB;
5557+
if (dimA != dimB) {
5558+
return false;
5559+
}
5560+
++i;
55685561
}
5569-
return i == rankA && j == rankB;
5562+
return true;
55705563
}
55715564

55725565
static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
@@ -5582,7 +5575,8 @@ static LogicalResult verifyVectorShapeCast(Operation *op,
55825575
ArrayRef<int64_t> inShape = sourceVectorType.getShape();
55835576
ArrayRef<int64_t> outShape = resultVectorType.getShape();
55845577

5585-
// Check that product of source dim sizes matches product of result dim sizes.
5578+
// Check that product of source dim sizes matches product of result dim
5579+
// sizes.
55865580
int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
55875581
std::multiplies<int64_t>{});
55885582
int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
@@ -5702,8 +5696,8 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
57025696
///
57035697
/// Looks at `vector.shape_cast` Ops that simply "drop" the trailing unit
57045698
/// dimension. If the input vector comes from `vector.create_mask` for which
5705-
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is safe
5706-
/// to fold shape_cast into create_mask.
5699+
/// the corresponding mask input value is 1 (e.g. `%c1` below), then it is
5700+
/// safe to fold shape_cast into create_mask.
57075701
///
57085702
/// BEFORE:
57095703
/// %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
@@ -5970,8 +5964,8 @@ LogicalResult TypeCastOp::verify() {
59705964
auto resultType = getResultMemRefType();
59715965
if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
59725966
getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
5973-
return emitOpError(
5974-
"expects result and operand with same underlying scalar type: ")
5967+
return emitOpError("expects result and operand with same underlying "
5968+
"scalar type: ")
59755969
<< resultType;
59765970
if (extractShape(sourceType) != extractShape(resultType))
59775971
return emitOpError(
@@ -6009,7 +6003,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
60096003
return attr.reshape(getResultVectorType());
60106004

60116005
// Eliminate identity transpose ops. This happens when the dimensions of the
6012-
// input vector remain in their original order after the transpose operation.
6006+
// input vector remain in their original order after the transpose
6007+
// operation.
60136008
ArrayRef<int64_t> perm = getPermutation();
60146009

60156010
// Check if the permutation of the dimensions contains sequential values:
@@ -6068,7 +6063,8 @@ class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
60686063
return result;
60696064
};
60706065

6071-
// Return if the input of 'transposeOp' is not defined by another transpose.
6066+
// Return if the input of 'transposeOp' is not defined by another
6067+
// transpose.
60726068
vector::TransposeOp parentTransposeOp =
60736069
transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
60746070
if (!parentTransposeOp)
@@ -6212,8 +6208,9 @@ LogicalResult ConstantMaskOp::verify() {
62126208
return emitOpError(
62136209
"only supports 'none set' or 'all set' scalable dimensions");
62146210
}
6215-
// Verify that if one mask dim size is zero, they all should be zero (because
6216-
// the mask region is a conjunction of each mask dimension interval).
6211+
// Verify that if one mask dim size is zero, they all should be zero
6212+
// (because the mask region is a conjunction of each mask dimension
6213+
// interval).
62176214
bool anyZeros = llvm::is_contained(maskDimSizes, 0);
62186215
bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
62196216
if (anyZeros && !allZeros)
@@ -6251,7 +6248,8 @@ void CreateMaskOp::build(OpBuilder &builder, OperationState &result,
62516248

62526249
LogicalResult CreateMaskOp::verify() {
62536250
auto vectorType = llvm::cast<VectorType>(getResult().getType());
6254-
// Verify that an operand was specified for each result vector each dimension.
6251+
// Verify that an operand was specified for each result vector each
6252+
// dimension.
62556253
if (vectorType.getRank() == 0) {
62566254
if (getNumOperands() != 1)
62576255
return emitOpError(
@@ -6458,8 +6456,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
64586456
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
64596457
OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
64606458
MaskOp>::ensureTerminator(region, builder, loc);
6461-
// Keep the default yield terminator if the number of masked operations is not
6462-
// the expected. This case will trigger a verification failure.
6459+
// Keep the default yield terminator if the number of masked operations is
6460+
// not the expected. This case will trigger a verification failure.
64636461
Block &block = region.front();
64646462
if (block.getOperations().size() != 2)
64656463
return;
@@ -6563,9 +6561,9 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
65636561
return success();
65646562
}
65656563

6566-
// Elides empty vector.mask operations with or without return values. Propagates
6567-
// the yielded values by the vector.yield terminator, if any, or erases the op,
6568-
// otherwise.
6564+
// Elides empty vector.mask operations with or without return values.
6565+
// Propagates the yielded values by the vector.yield terminator, if any, or
6566+
// erases the op, otherwise.
65696567
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
65706568
using OpRewritePattern::OpRewritePattern;
65716569

@@ -6668,7 +6666,8 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
66686666
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
66696667
return {};
66706668

6671-
// SplatElementsAttr::get treats single value for second arg as being a splat.
6669+
// SplatElementsAttr::get treats single value for second arg as being a
6670+
// splat.
66726671
return SplatElementsAttr::get(getType(), {constOperand});
66736672
}
66746673

@@ -6790,12 +6789,12 @@ Operation *mlir::vector::maskOperation(OpBuilder &builder,
67906789
}
67916790

67926791
/// Creates a vector select operation that picks values from `newValue` or
6793-
/// `passthru` for each result vector lane based on `mask`. This utility is used
6794-
/// to propagate the pass-thru value of vector.mask or for cases where only the
6795-
/// pass-thru value propagation is needed. VP intrinsics do not support
6796-
/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
6797-
/// usually able to match op + select patterns and fold them into a native
6798-
/// target instructions.
6792+
/// `passthru` for each result vector lane based on `mask`. This utility is
6793+
/// used to propagate the pass-thru value of vector.mask or for cases where
6794+
/// only the pass-thru value propagation is needed. VP intrinsics do not
6795+
/// support pass-thru values and every mask-out lane is set to poison. LLVM
6796+
/// backends are usually able to match op + select patterns and fold them into
6797+
/// a native target instructions.
67996798
Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
68006799
Value newValue, Value passthru) {
68016800
if (!mask)

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
950950

951951
// -----
952952

953+
// The definition of shape_cast stipulates that it must be either expanding or collapsing,
954+
// it cannot be a mixture of both.
953955
// CHECK-LABEL: dont_fold_expand_collapse
954-
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
955-
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
956-
// CHECK: return %[[B]] : vector<8x8xf32>
957-
func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
958-
%0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
959-
%1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
960-
return %1 : vector<8x8xf32>
956+
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32>
957+
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32>
958+
// CHECK: return %[[B]] : vector<4x3x3xf32>
959+
func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> {
960+
%0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32>
961+
%1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32>
962+
return %1 : vector<4x3x3xf32>
961963
}
962964

963965
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,9 +581,29 @@ func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32
581581

582582
// CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
583583
%0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
584+
584585
return %0 : vector<4x1xf32>
585586
}
586587

588+
589+
// CHECK-LABEL: @collapse_but_increase_rank
590+
func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> {
591+
592+
// CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
593+
%0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
594+
595+
return %0 : vector<1x6x1x35x1xf32>
596+
}
597+
598+
// CHECK-LABEL: @expand_but_decrease_rank
599+
func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> {
600+
601+
// CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8>
602+
%0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8>
603+
604+
return %0 : vector<2x3xi8>
605+
}
606+
587607
// CHECK-LABEL: @bitcast
588608
func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
589609
%arg1 : vector<8x1xi32>,

0 commit comments

Comments
 (0)