Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mlir/lib/Interfaces/DataLayoutInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ mlir::detail::getDefaultTypeSize(Type type, const DataLayout &dataLayout,
llvm::TypeSize
mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
DataLayoutEntryListRef params) {
if (isa<IntegerType, FloatType>(type))
if (type.isIntOrFloat())
return llvm::TypeSize::getFixed(type.getIntOrFloatBitWidth());

if (auto ctype = dyn_cast<ComplexType>(type)) {
Expand Down Expand Up @@ -720,7 +720,7 @@ LogicalResult mlir::detail::verifyDataLayoutSpec(DataLayoutSpecInterface spec,
continue;
}

if (isa<IntegerType, FloatType>(sampleType)) {
if (sampleType.isIntOrFloat()) {
for (DataLayoutEntryInterface entry : kvp.second) {
auto value = dyn_cast<DenseIntElementsAttr>(entry.getValue());
if (!value || !value.getElementType().isSignlessInteger(64)) {
Expand Down
11 changes: 3 additions & 8 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -759,11 +759,6 @@ void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
iface->setAttr(iface.getFastmathAttrName(), attr);
}

/// Returns if `type` is a scalar integer or floating-point type.
static bool isScalarType(Type type) {
return isa<IntegerType, FloatType>(type);
}

/// Returns `type` if it is a builtin integer or floating-point vector type that
/// can be used to create an attribute or nullptr otherwise. If provided,
/// `arrayShape` is added to the shape of the vector to create an attribute that
Expand All @@ -781,7 +776,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {

// An LLVM dialect vector can only contain scalars.
Type elementType = LLVM::getVectorElementType(type);
if (!isScalarType(elementType))
if (!elementType.isIntOrFloat())
return {};

SmallVector<int64_t> shape(arrayShape);
Expand All @@ -794,7 +789,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
return {};

// Return builtin integer and floating-point types as is.
if (isScalarType(type))
if (type.isIntOrFloat())
return type;

// Return builtin vectors of integer and floating-point types as is.
Expand All @@ -808,7 +803,7 @@ Type ModuleImport::getBuiltinTypeForAttr(Type type) {
arrayShape.push_back(arrayType.getNumElements());
type = arrayType.getElementType();
}
if (isScalarType(type))
if (type.isIntOrFloat())
return RankedTensorType::get(arrayShape, type);
return getVectorTypeForAttr(type, arrayShape);
}
Expand Down