From 1d6096bf42e0b656631eb7a34acad0ff6d046815 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 14 Apr 2025 13:20:17 +0200 Subject: [PATCH] [mlir][vector] Tighten the semantics of vector.gather MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch restricts `vector.gather` to only accept tensors and memrefs as valid sources. Currently, the source is typed as `AnyShaped`, which also includes vectors—allowing the following (invalid) construct to pass verification: ```mlir %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> ``` (Note: the source %base here is a vector, which is incorrect.) In contrast, `vector.scatter` currently only accepts memrefs, so some asymmetry remains between the two ops. This PR is a step toward aligning their semantics. --- .../mlir/Dialect/Vector/IR/VectorOps.td | 2 +- mlir/include/mlir/IR/CommonTypeConstraints.td | 15 ++++++++++++- mlir/test/Dialect/Vector/invalid.mlir | 21 +++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 134472cefbf4e..8ae5961af41bb 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1971,7 +1971,7 @@ def Vector_GatherOp : DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, - Arguments<(ins Arg:$base, + Arguments<(ins Arg, "", [MemRead]>:$base, Variadic:$indices, VectorOfNonZeroRankOf<[AnyInteger, Index]>:$index_vec, VectorOfNonZeroRankOf<[I1]>:$mask, diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index e6f17ded4628b..45ec1846580f2 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -63,6 +63,9 @@ def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">; // Whether a type is a MemRefType. def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">; +// Whether a type is a TensorType or a MemRefType. +def IsTensorOrMemRefTypePred : Or<[IsTensorTypePred, IsMemRefTypePred]>; + // Whether a type is an UnrankedMemRefType def IsUnrankedMemRefTypePred : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">; @@ -426,7 +429,9 @@ class ValueSemanticsContainerOf allowedTypes> : ShapedContainerType; +//===----------------------------------------------------------------------===// // Vector types. +//===----------------------------------------------------------------------===// class VectorOfNonZeroRankOf allowedTypes> : ShapedContainerType allowedTypes> def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; //===----------------------------------------------------------------------===// -// Memref type. +// Memref types. //===----------------------------------------------------------------------===// // Any unranked memref whose element type is from the given `allowedTypes` list. @@ -878,6 +883,14 @@ class NestedTupleOf allowedTypes> : "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))", "nested tuple">; +//===----------------------------------------------------------------------===// +// Mixed types +//===----------------------------------------------------------------------===// + +class TensorOrMemRef allowedTypes> : + ShapedContainerType; + //===----------------------------------------------------------------------===// // Common type constraints //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index ea6d0021391fb..c6e780c6641fd 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1409,6 +1409,16 @@ func.func @maskedstore_memref_mismatch(%base: memref, %mask: vector<16xi1 // ----- +func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + %c0 = arith.constant 0 : index + // expected-error@+1 {{'vector.gather' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}} + %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru + : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +} + +// ----- + func.func @gather_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { %c0 = arith.constant 0 : index @@ -1469,6 +1479,17 @@ func.func @gather_pass_thru_type_mismatch(%base: memref, %indices: vector // ----- +func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>, + %mask: vector<16xi1>, %pass_thru: vector<16xf32>) { + %c0 = arith.constant 0 : index + // expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}} + vector.scatter %base[%c0][%indices], %mask, %pass_thru + : vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> +} + +// ----- + + func.func @scatter_base_type_mismatch(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index