diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index f5700059f68ee..877ac87fb0fe5 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -292,6 +292,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv, } /// Converts a sub-byte integer `type` to i32 regardless of target environment. +/// Returns a nullptr for unsupported integer types, including non sub-byte +/// types. /// /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use /// the above given that these sub-byte types are not supported at all in @@ -299,6 +301,10 @@ convertScalarType(const spirv::TargetEnv &targetEnv, /// supported integer types. static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, IntegerType type) { + if (type.getWidth() > 8) { + LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n"); + return nullptr; + } if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) { LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n"); return nullptr; @@ -348,6 +354,9 @@ convertVectorType(const spirv::TargetEnv &targetEnv, } Type elementType = convertSubByteIntegerType(options, intType); + if (!elementType) + return nullptr; + if (type.getRank() <= 1 && type.getNumElements() == 1) return elementType; diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir index 24a0bab352c34..9d7ab2be096ef 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir @@ -60,6 +60,14 @@ func.func @int_vector4_invalid(%arg0: vector<2xi16>) { return } +// ----- + +func.func @int_vector_invalid_bitwidth(%arg0: vector<2xi12>) { + // expected-error @+1 {{failed to legalize operation 'arith.addi'}} + %0 = arith.addi %arg0, %arg0: vector<2xi12> + return +} + ///===----------------------------------------------------------------------===// // Constant ops //===----------------------------------------------------------------------===//