diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index acd0f894abbbe..30e151ece4850 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -125,6 +125,7 @@ class Type { // Convenience predicates. This is only for floating point types, // derived types should use isa/dyn_cast. bool isIndex() const; + bool isFloat() const; bool isFloat4E2M1FN() const; bool isFloat6E2M3FN() const; bool isFloat6E3M2FN() const; @@ -164,10 +165,10 @@ class Type { /// Return true if this is a signless integer or index type. bool isSignlessIntOrIndex() const; - /// Return true if this is a signless integer, index, or float type. - bool isSignlessIntOrIndexOrFloat() const; /// Return true of this is a signless integer or a float type. bool isSignlessIntOrFloat() const; + /// Return true if this is a signless integer, index, or float type. + bool isSignlessIntOrIndexOrFloat() const; /// Return true if this is an integer (of any signedness) or an index type. bool isIntOrIndex() const; diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index e190902b2e489..5914faa327ad2 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -63,6 +63,8 @@ bool Type::isF128() const { return llvm::isa(*this); } bool Type::isIndex() const { return llvm::isa(*this); } +bool Type::isFloat() const { return llvm::isa(*this); } + bool Type::isInteger() const { return llvm::isa(*this); } /// Return true if this is an integer type with the specified width. @@ -109,26 +111,22 @@ bool Type::isUnsignedInteger(unsigned width) const { } bool Type::isSignlessIntOrIndex() const { - return isSignlessInteger() || llvm::isa(*this); -} - -bool Type::isSignlessIntOrIndexOrFloat() const { - return isSignlessInteger() || llvm::isa(*this); + return isSignlessInteger() || isIndex(); } bool Type::isSignlessIntOrFloat() const { - return isSignlessInteger() || llvm::isa(*this); + return isSignlessInteger() || isFloat(); } -bool Type::isIntOrIndex() const { - return llvm::isa(*this) || isIndex(); +bool Type::isSignlessIntOrIndexOrFloat() const { + return isSignlessIntOrIndex() || isFloat(); } -bool Type::isIntOrFloat() const { - return llvm::isa(*this); -} +bool Type::isIntOrIndex() const { return isInteger() || isIndex(); } + +bool Type::isIntOrFloat() const { return isInteger() || isFloat(); } -bool Type::isIntOrIndexOrFloat() const { return isIntOrFloat() || isIndex(); } +bool Type::isIntOrIndexOrFloat() const { return isIntOrIndex() || isFloat(); } unsigned Type::getIntOrFloatBitWidth() const { assert(isIntOrFloat() && "only integers and floats have a bitwidth");