diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c02b16ea93170..e859270cf9a5e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1819,17 +1819,17 @@ def Vector_MaskedLoadOp : Vector_Op<"maskedload">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$pass_thru)>, - Results<(outs VectorOfRank<[1]>:$result)> { + VectorOf<[I1]>:$mask, + AnyVector:$pass_thru)>, + Results<(outs AnyVector:$result)> { let summary = "loads elements from memory into a vector as defined by a mask vector"; let description = [{ - The masked load reads elements from memory into a 1-D vector as defined - by a base with indices and a 1-D mask vector. When the mask is set, the + The masked load reads elements from memory into a vector as defined + by a base with indices and a mask vector. When the mask is set, the element is read from memory. Otherwise, the corresponding element is taken - from a 1-D pass-through vector. Informally the semantics are: + from a pass-through vector. Informally the semantics are: ``` result[0] := if mask[0] then base[i + 0] else pass_thru[0] result[1] := if mask[1] then base[i + 1] else pass_thru[1] @@ -1882,14 +1882,14 @@ def Vector_MaskedStoreOp : Vector_Op<"maskedstore">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$valueToStore)> { + VectorOf<[I1]>:$mask, + AnyVector:$valueToStore)> { let summary = "stores elements from a vector into memory as defined by a mask vector"; let description = [{ - The masked store operation writes elements from a 1-D vector into memory - as defined by a base with indices and a 1-D mask vector. When the mask is + The masked store operation writes elements from a vector into memory + as defined by a base with indices and a mask vector. When the mask is set, the corresponding element from the vector is written to memory. Otherwise, no action is taken for the element. Informally the semantics are: ``` @@ -2076,23 +2076,26 @@ def Vector_ExpandLoadOp : Vector_Op<"expandload">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$pass_thru)>, - Results<(outs VectorOfRank<[1]>:$result)> { + VectorOf<[I1]>:$mask, + AnyVector:$pass_thru)>, + Results<(outs AnyVector:$result)> { let summary = "reads elements from memory and spreads them into a vector as defined by a mask"; let description = [{ - The expand load reads elements from memory into a 1-D vector as defined - by a base with indices and a 1-D mask vector. When the mask is set, the - next element is read from memory. Otherwise, the corresponding element - is taken from a 1-D pass-through vector. Informally the semantics are: + The expand load reads elements from memory into a vector as defined by a + base with indices and a mask vector. Expansion only applies to the innermost + dimension. When the mask is set, the next element is read from memory. + Otherwise, the corresponding element is taken from a pass-through vector. + Informally the semantics are: + ``` index = i result[0] := if mask[0] then base[index++] else pass_thru[0] result[1] := if mask[1] then base[index++] else pass_thru[1] etc. ``` + Note that the index increment is done conditionally. If a mask bit is set and the corresponding index is out-of-bounds for the @@ -2140,22 +2143,25 @@ def Vector_CompressStoreOp : Vector_Op<"compressstore">, Arguments<(ins Arg:$base, Variadic:$indices, - VectorOfRankAndType<[1], [I1]>:$mask, - VectorOfRank<[1]>:$valueToStore)> { + VectorOf<[I1]>:$mask, + AnyVector:$valueToStore)> { let summary = "writes elements selectively from a vector as defined by a mask"; let description = [{ - The compress store operation writes elements from a 1-D vector into memory - as defined by a base with indices and a 1-D mask vector. When the mask is - set, the corresponding element from the vector is written next to memory. - Otherwise, no action is taken for the element. Informally the semantics are: + The compress store operation writes elements from a vector into memory as + defined by a base with indices and a mask vector. Compression only applies + to the innermost dimension. When the mask is set, the corresponding element + from the vector is written next to memory. Otherwise, no action is taken + for the element. Informally the semantics are: + ``` index = i if (mask[0]) base[index++] = value[0] if (mask[1]) base[index++] = value[1] etc. ``` + Note that the index increment is done conditionally. If a mask bit is set and the corresponding index is out-of-bounds for the diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a2abe1619454f..d71a236f62f45 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -4977,8 +4977,8 @@ LogicalResult MaskedLoadOp::verify() { return emitOpError("base and result element type should match"); if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; - if (resVType.getDimSize(0) != maskVType.getDimSize(0)) - return emitOpError("expected result dim to match mask dim"); + if (resVType.getShape() != maskVType.getShape()) + return emitOpError("expected result shape to match mask shape"); if (resVType != passVType) return emitOpError("expected pass_thru of same type as result type"); return success(); @@ -5030,8 +5030,8 @@ LogicalResult MaskedStoreOp::verify() { return emitOpError("base and valueToStore element type should match"); if (llvm::size(getIndices()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; - if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) - return emitOpError("expected valueToStore dim to match mask dim"); + if (valueVType.getShape() != maskVType.getShape()) + return emitOpError("expected valueToStore shape to match mask shape"); return success(); } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 36d04bb77e3b9..5b0fb537b3565 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1356,7 +1356,7 @@ func.func @maskedload_base_type_mismatch(%base: memref, %mask: vector<16x func.func @maskedload_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %pass: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}} + // expected-error@+1 {{'vector.maskedload' op expected result shape to match mask shape}} %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<15xi1>, vector<16xf32> into vector<16xf32> } @@ -1387,7 +1387,7 @@ func.func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16 func.func @maskedstore_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedstore' op expected valueToStore dim to match mask dim}} + // expected-error@+1 {{'vector.maskedstore' op expected valueToStore shape to match mask shape}} vector.maskedstore %base[%c0], %mask, %value : memref, vector<15xi1>, vector<16xf32> }