Skip to content

Commit 0b96a31

Browse files
committed
[MLIR] Merge AnyVector and AnyVectorOfAnyRank type constraints.
1 parent 9d7b35d commit 0b96a31

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def Vector_ReductionOp :
224224
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
225225
]>,
226226
Arguments<(ins Vector_CombiningKindAttr:$kind,
227-
AnyVectorOfAnyRank:$vector,
227+
AnyVector:$vector,
228228
Optional<AnyType>:$acc,
229229
DefaultValuedAttr<
230230
Arith_FastMathAttr,
@@ -349,7 +349,7 @@ def Vector_BroadcastOp :
349349
PredOpTrait<"source operand and result have same element type",
350350
TCresVTEtIsSameAsOpBase<0, 0>>]>,
351351
Arguments<(ins AnyType:$source)>,
352-
Results<(outs AnyVectorOfAnyRank:$vector)> {
352+
Results<(outs AnyVector:$vector)> {
353353
let summary = "broadcast operation";
354354
let description = [{
355355
Broadcasts the scalar or k-D vector value in the source operand
@@ -528,7 +528,7 @@ def Vector_InterleaveOp :
528528
```
529529
}];
530530

531-
let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
531+
let arguments = (ins AnyVector:$lhs, AnyVector:$rhs);
532532
let results = (outs AnyVector:$result);
533533

534534
let assemblyFormat = [{
@@ -630,7 +630,7 @@ def Vector_ExtractElementOp :
630630
TypesMatchWith<"result type matches element type of vector operand",
631631
"vector", "result",
632632
"::llvm::cast<VectorType>($_self).getElementType()">]>,
633-
Arguments<(ins AnyVectorOfAnyRank:$vector,
633+
Arguments<(ins AnyVector:$vector,
634634
Optional<AnySignlessIntegerOrIndex>:$position)>,
635635
Results<(outs AnyType:$result)> {
636636
let summary = "extractelement operation";
@@ -697,7 +697,7 @@ def Vector_ExtractOp :
697697
}];
698698

699699
let arguments = (ins
700-
AnyVectorOfAnyRank:$vector,
700+
AnyVector:$vector,
701701
Variadic<Index>:$dynamic_position,
702702
DenseI64ArrayAttr:$static_position
703703
);
@@ -803,7 +803,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
803803
}];
804804

805805
let arguments = (ins Variadic<AnyType>:$elements);
806-
let results = (outs AnyVectorOfAnyRank:$result);
806+
let results = (outs AnyVector:$result);
807807
let assemblyFormat = "$elements attr-dict `:` type($result)";
808808
let hasCanonicalizer = 1;
809809
}
@@ -814,9 +814,9 @@ def Vector_InsertElementOp :
814814
"result", "source",
815815
"::llvm::cast<VectorType>($_self).getElementType()">,
816816
AllTypesMatch<["dest", "result"]>]>,
817-
Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
817+
Arguments<(ins AnyType:$source, AnyVector:$dest,
818818
Optional<AnySignlessIntegerOrIndex>:$position)>,
819-
Results<(outs AnyVectorOfAnyRank:$result)> {
819+
Results<(outs AnyVector:$result)> {
820820
let summary = "insertelement operation";
821821
let description = [{
822822
Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
@@ -884,11 +884,11 @@ def Vector_InsertOp :
884884

885885
let arguments = (ins
886886
AnyType:$source,
887-
AnyVectorOfAnyRank:$dest,
887+
AnyVector:$dest,
888888
Variadic<Index>:$dynamic_position,
889889
DenseI64ArrayAttr:$static_position
890890
);
891-
let results = (outs AnyVectorOfAnyRank:$result);
891+
let results = (outs AnyVector:$result);
892892

893893
let builders = [
894894
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
@@ -1250,7 +1250,7 @@ def Vector_TransferReadOp :
12501250
AnyType:$padding,
12511251
Optional<VectorOf<[I1]>>:$mask,
12521252
BoolArrayAttr:$in_bounds)>,
1253-
Results<(outs AnyVectorOfAnyRank:$vector)> {
1253+
Results<(outs AnyVector:$vector)> {
12541254

12551255
let summary = "Reads a supervector from memory into an SSA vector value.";
12561256

@@ -1492,7 +1492,7 @@ def Vector_TransferWriteOp :
14921492
AttrSizedOperandSegments,
14931493
DestinationStyleOpInterface
14941494
]>,
1495-
Arguments<(ins AnyVectorOfAnyRank:$vector,
1495+
Arguments<(ins AnyVector:$vector,
14961496
AnyShaped:$source,
14971497
Variadic<Index>:$indices,
14981498
AffineMapAttr:$permutation_map,
@@ -1710,7 +1710,7 @@ def Vector_LoadOp : Vector_Op<"load"> {
17101710
[MemRead]>:$base,
17111711
Variadic<Index>:$indices,
17121712
DefaultValuedOptionalAttr<BoolAttr, "false">:$nontemporal);
1713-
let results = (outs AnyVectorOfAnyRank:$result);
1713+
let results = (outs AnyVector:$result);
17141714

17151715
let extraClassDeclaration = [{
17161716
MemRefType getMemRefType() {
@@ -1791,7 +1791,7 @@ def Vector_StoreOp : Vector_Op<"store"> {
17911791
}];
17921792

17931793
let arguments = (ins
1794-
AnyVectorOfAnyRank:$valueToStore,
1794+
AnyVector:$valueToStore,
17951795
Arg<AnyMemRef, "the reference to store to",
17961796
[MemWrite]>:$base,
17971797
Variadic<Index>:$indices,
@@ -2199,8 +2199,8 @@ def Vector_CompressStoreOp :
21992199

22002200
def Vector_ShapeCastOp :
22012201
Vector_Op<"shape_cast", [Pure]>,
2202-
Arguments<(ins AnyVectorOfAnyRank:$source)>,
2203-
Results<(outs AnyVectorOfAnyRank:$result)> {
2202+
Arguments<(ins AnyVector:$source)>,
2203+
Results<(outs AnyVector:$result)> {
22042204
let summary = "shape_cast casts between vector shapes";
22052205
let description = [{
22062206
The shape_cast operation casts between an n-D source vector shape and
@@ -2251,8 +2251,8 @@ def Vector_ShapeCastOp :
22512251

22522252
def Vector_BitCastOp :
22532253
Vector_Op<"bitcast", [Pure, AllRanksMatch<["source", "result"]>]>,
2254-
Arguments<(ins AnyVectorOfAnyRank:$source)>,
2255-
Results<(outs AnyVectorOfAnyRank:$result)>{
2254+
Arguments<(ins AnyVector:$source)>,
2255+
Results<(outs AnyVector:$result)>{
22562256
let summary = "bitcast casts between vectors";
22572257
let description = [{
22582258
The bitcast operation casts between vectors of the same rank, the minor 1-D
@@ -2561,9 +2561,9 @@ def Vector_TransposeOp :
25612561
```
25622562
}];
25632563

2564-
let arguments = (ins AnyVectorOfAnyRank:$vector,
2564+
let arguments = (ins AnyVector:$vector,
25652565
DenseI64ArrayAttr:$permutation);
2566-
let results = (outs AnyVectorOfAnyRank:$result);
2566+
let results = (outs AnyVector:$result);
25672567

25682568
let builders = [
25692569
OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$permutation)>
@@ -2593,7 +2593,7 @@ def Vector_PrintOp :
25932593
>,
25942594
]>,
25952595
Arguments<(ins Optional<Type<Or<[
2596-
AnyVectorOfAnyRank.predicate,
2596+
AnyVector.predicate,
25972597
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
25982598
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
25992599
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
@@ -2814,7 +2814,7 @@ def Vector_SplatOp : Vector_Op<"splat", [
28142814

28152815
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
28162816
"integer/index/float type">:$input);
2817-
let results = (outs AnyVectorOfAnyRank:$aggregate);
2817+
let results = (outs AnyVector:$aggregate);
28182818

28192819
let builders = [
28202820
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
@@ -2873,11 +2873,11 @@ def Vector_ScanOp :
28732873
AllTypesMatch<["initial_value", "accumulated_value"]> ]>,
28742874
Arguments<(ins Vector_CombiningKindAttr:$kind,
28752875
AnyVector:$source,
2876-
AnyVectorOfAnyRank:$initial_value,
2876+
AnyVector:$initial_value,
28772877
I64Attr:$reduction_dim,
28782878
BoolAttr:$inclusive)>,
28792879
Results<(outs AnyVector:$dest,
2880-
AnyVectorOfAnyRank:$accumulated_value)> {
2880+
AnyVector:$accumulated_value)> {
28812881
let summary = "Scan operation";
28822882
let description = [{
28832883
Performs an inclusive/exclusive scan on an n-D vector along a single

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,7 @@ class VectorWithTrailingDimScalableOfSizeAndType<list<int> allowedTrailingSizes,
657657
ShapedTypeWithNthDimOfSize<-1, allowedTrailingSizes>.summary,
658658
"::mlir::VectorType">;
659659

660-
def AnyVector : VectorOf<[AnyType]>;
661-
// Temporary vector type clone that allows gradual transition to 0-D vectors.
662-
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
660+
def AnyVector : VectorOfAnyRankOf<[AnyType]>;
663661

664662
def AnyFixedVector : FixedVectorOf<[AnyType]>;
665663

0 commit comments

Comments
 (0)