diff --git a/mlir/include/mlir/IR/VectorTypes.h b/mlir/include/mlir/IR/VectorTypes.h new file mode 100644 index 0000000000000..c209f869a579d --- /dev/null +++ b/mlir/include/mlir/IR/VectorTypes.h @@ -0,0 +1,51 @@ +//===- VectorTypes.h - MLIR Vector Types ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Convenience wrappers for `VectorType` to allow idiomatic code like +// * isa(type) +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_VECTORTYPES_H +#define MLIR_IR_VECTORTYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace vector { + +/// A vector type containing at least one scalable dimension. +class ScalableVectorType : public VectorType { +public: + using VectorType::VectorType; + + static bool classof(Type type) { + auto vecTy = llvm::dyn_cast(type); + if (!vecTy) + return false; + return vecTy.isScalable(); + } +}; + +/// A vector type with no scalable dimensions. +class FixedVectorType : public VectorType { +public: + using VectorType::VectorType; + static bool classof(Type type) { + auto vecTy = llvm::dyn_cast(type); + if (!vecTy) + return false; + return !vecTy.isScalable(); + } +}; + +} // namespace vector +} // namespace mlir + +#endif // MLIR_IR_VECTORTYPES_H diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 16b4e8eb4f022..e65258786a768 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -21,6 +21,8 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/VectorTypes.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" @@ -224,8 +226,8 @@ LogicalResult arith::ConstantOp::verify() { // Note, we could relax this for vectors with 1 scalable dim, e.g.: // * arith.constant dense<[[3, 3], [1, 1]]> : vector<2 x [2] x i32> // However, this would most likely require updating the lowerings to LLVM. - auto vecType = dyn_cast(type); - if (vecType && vecType.isScalable() && !isa(getValue())) + if (isa(type) && + !isa(getValue())) return emitOpError( "intializing scalable vectors with elements attribute is not supported" " unless it's a vector splat");