Skip to content

Commit bb00f5b

Browse files
authored
[mlir][vector] Remove unneeded mask restriction (#113742)
These were added when the only mapping was to LLVM.
1 parent 5aa741d commit bb00f5b

File tree

3 files changed

+35
-29
lines changed

3 files changed

+35
-29
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,17 +1819,17 @@ def Vector_MaskedLoadOp :
18191819
Vector_Op<"maskedload">,
18201820
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
18211821
Variadic<Index>:$indices,
1822-
VectorOfRankAndType<[1], [I1]>:$mask,
1823-
VectorOfRank<[1]>:$pass_thru)>,
1824-
Results<(outs VectorOfRank<[1]>:$result)> {
1822+
VectorOf<[I1]>:$mask,
1823+
AnyVector:$pass_thru)>,
1824+
Results<(outs AnyVector:$result)> {
18251825

18261826
let summary = "loads elements from memory into a vector as defined by a mask vector";
18271827

18281828
let description = [{
1829-
The masked load reads elements from memory into a 1-D vector as defined
1830-
by a base with indices and a 1-D mask vector. When the mask is set, the
1829+
The masked load reads elements from memory into a vector as defined
1830+
by a base with indices and a mask vector. When the mask is set, the
18311831
element is read from memory. Otherwise, the corresponding element is taken
1832-
from a 1-D pass-through vector. Informally the semantics are:
1832+
from a pass-through vector. Informally the semantics are:
18331833
```
18341834
result[0] := if mask[0] then base[i + 0] else pass_thru[0]
18351835
result[1] := if mask[1] then base[i + 1] else pass_thru[1]
@@ -1882,14 +1882,14 @@ def Vector_MaskedStoreOp :
18821882
Vector_Op<"maskedstore">,
18831883
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
18841884
Variadic<Index>:$indices,
1885-
VectorOfRankAndType<[1], [I1]>:$mask,
1886-
VectorOfRank<[1]>:$valueToStore)> {
1885+
VectorOf<[I1]>:$mask,
1886+
AnyVector:$valueToStore)> {
18871887

18881888
let summary = "stores elements from a vector into memory as defined by a mask vector";
18891889

18901890
let description = [{
1891-
The masked store operation writes elements from a 1-D vector into memory
1892-
as defined by a base with indices and a 1-D mask vector. When the mask is
1891+
The masked store operation writes elements from a vector into memory
1892+
as defined by a base with indices and a mask vector. When the mask is
18931893
set, the corresponding element from the vector is written to memory. Otherwise,
18941894
no action is taken for the element. Informally the semantics are:
18951895
```
@@ -2076,23 +2076,26 @@ def Vector_ExpandLoadOp :
20762076
Vector_Op<"expandload">,
20772077
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
20782078
Variadic<Index>:$indices,
2079-
VectorOfRankAndType<[1], [I1]>:$mask,
2080-
VectorOfRank<[1]>:$pass_thru)>,
2081-
Results<(outs VectorOfRank<[1]>:$result)> {
2079+
VectorOf<[I1]>:$mask,
2080+
AnyVector:$pass_thru)>,
2081+
Results<(outs AnyVector:$result)> {
20822082

20832083
let summary = "reads elements from memory and spreads them into a vector as defined by a mask";
20842084

20852085
let description = [{
2086-
The expand load reads elements from memory into a 1-D vector as defined
2087-
by a base with indices and a 1-D mask vector. When the mask is set, the
2088-
next element is read from memory. Otherwise, the corresponding element
2089-
is taken from a 1-D pass-through vector. Informally the semantics are:
2086+
The expand load reads elements from memory into a vector as defined by a
2087+
base with indices and a mask vector. Expansion only applies to the innermost
2088+
dimension. When the mask is set, the next element is read from memory.
2089+
Otherwise, the corresponding element is taken from a pass-through vector.
2090+
Informally the semantics are:
2091+
20902092
```
20912093
index = i
20922094
result[0] := if mask[0] then base[index++] else pass_thru[0]
20932095
result[1] := if mask[1] then base[index++] else pass_thru[1]
20942096
etc.
20952097
```
2098+
20962099
Note that the index increment is done conditionally.
20972100

20982101
If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2140,22 +2143,25 @@ def Vector_CompressStoreOp :
21402143
Vector_Op<"compressstore">,
21412144
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
21422145
Variadic<Index>:$indices,
2143-
VectorOfRankAndType<[1], [I1]>:$mask,
2144-
VectorOfRank<[1]>:$valueToStore)> {
2146+
VectorOf<[I1]>:$mask,
2147+
AnyVector:$valueToStore)> {
21452148

21462149
let summary = "writes elements selectively from a vector as defined by a mask";
21472150

21482151
let description = [{
2149-
The compress store operation writes elements from a 1-D vector into memory
2150-
as defined by a base with indices and a 1-D mask vector. When the mask is
2151-
set, the corresponding element from the vector is written next to memory.
2152-
Otherwise, no action is taken for the element. Informally the semantics are:
2152+
The compress store operation writes elements from a vector into memory as
2153+
defined by a base with indices and a mask vector. Compression only applies
2154+
to the innermost dimension. When the mask is set, the corresponding element
2155+
from the vector is written next to memory. Otherwise, no action is taken
2156+
for the element. Informally the semantics are:
2157+
21532158
```
21542159
index = i
21552160
if (mask[0]) base[index++] = value[0]
21562161
if (mask[1]) base[index++] = value[1]
21572162
etc.
21582163
```
2164+
21592165
Note that the index increment is done conditionally.
21602166

21612167
If a mask bit is set and the corresponding index is out-of-bounds for the

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4977,8 +4977,8 @@ LogicalResult MaskedLoadOp::verify() {
49774977
return emitOpError("base and result element type should match");
49784978
if (llvm::size(getIndices()) != memType.getRank())
49794979
return emitOpError("requires ") << memType.getRank() << " indices";
4980-
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4981-
return emitOpError("expected result dim to match mask dim");
4980+
if (resVType.getShape() != maskVType.getShape())
4981+
return emitOpError("expected result shape to match mask shape");
49824982
if (resVType != passVType)
49834983
return emitOpError("expected pass_thru of same type as result type");
49844984
return success();
@@ -5030,8 +5030,8 @@ LogicalResult MaskedStoreOp::verify() {
50305030
return emitOpError("base and valueToStore element type should match");
50315031
if (llvm::size(getIndices()) != memType.getRank())
50325032
return emitOpError("requires ") << memType.getRank() << " indices";
5033-
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
5034-
return emitOpError("expected valueToStore dim to match mask dim");
5033+
if (valueVType.getShape() != maskVType.getShape())
5034+
return emitOpError("expected valueToStore shape to match mask shape");
50355035
return success();
50365036
}
50375037

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,7 @@ func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16x
13561356

13571357
func.func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
13581358
%c0 = arith.constant 0 : index
1359-
// expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
1359+
// expected-error@+1 {{'vector.maskedload' op expected result shape to match mask shape}}
13601360
%0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
13611361
}
13621362

@@ -1387,7 +1387,7 @@ func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16
13871387

13881388
func.func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
13891389
%c0 = arith.constant 0 : index
1390-
// expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}}
1390+
// expected-error@+1 {{'vector.maskedstore' op expected valueToStore shape to match mask shape}}
13911391
vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
13921392
}
13931393

0 commit comments

Comments
 (0)