Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 73 additions & 29 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
[Pure, SameOperandsAndResultType])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
// TODO: Arithmetic operations that use this definition do not support cooperative matrices,
// these need to be fixed.
let arguments = (ins
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand1,
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand2
Expand All @@ -37,20 +37,43 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
let assemblyFormat = "operands attr-dict `:` type($result)";
}

class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
class SPIRV_ArithmeticWithCoopMatrixBinaryOp<string mnemonic,
Type scalarVectorType,
Type coopMatrixType,
list<Trait> traits = []> :
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, coopMatrixType, coopMatrixType,
!listconcat(traits,
[Pure, SameOperandsAndResultType])> {
// In addition to normal types these arithmetic instructions can support
// cooperative matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand1,
SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand2
);

let results = (outs
SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

class SPIRV_ArithmeticUnaryOp<string mnemonic,
Type scalarVectorType,
Type coopMatrixType,
list<Trait> traits = []> :
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
SPIRV_UnaryOp<mnemonic, coopMatrixType, coopMatrixType,
!listconcat(traits,
[Pure, SameOperandsAndResultType])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand
SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$operand1
);

let results = (outs
SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
SPIRV_ScalarOrVectorOfOrCoopMatrixOf<scalarVectorType, coopMatrixType>:$result
);
let assemblyFormat = "operands attr-dict `:` type($result)";
}
Expand Down Expand Up @@ -82,7 +105,10 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,

// -----

def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
def SPIRV_FAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FAdd",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
[Commutative]> {
let summary = "Floating-point addition of Operand 1 and Operand 2.";

let description = [{
Expand All @@ -104,7 +130,10 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>

// -----

def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
def SPIRV_FDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FDiv",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
[]> {
let summary = "Floating-point division of Operand 1 divided by Operand 2.";

let description = [{
Expand Down Expand Up @@ -154,7 +183,10 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {

// -----

def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
def SPIRV_FMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FMul",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
[Commutative]> {
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";

let description = [{
Expand All @@ -176,7 +208,10 @@ def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]>

// -----

def SPIRV_FNegateOp : SPIRV_ArithmeticUnaryOp<"FNegate", SPIRV_Float, []> {
def SPIRV_FNegateOp : SPIRV_ArithmeticUnaryOp<"FNegate",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
[]> {
let summary = [{
Inverts the sign bit of Operand. (Note, however, that OpFNegate is still
considered a floating-point instruction, and so is subject to the
Expand Down Expand Up @@ -229,7 +264,10 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {

// -----

def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
def SPIRV_FSubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"FSub",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
[]> {
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";

let description = [{
Expand All @@ -251,9 +289,10 @@ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {

// -----

def SPIRV_IAddOp : SPIRV_ArithmeticBinaryOp<"IAdd",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
def SPIRV_IAddOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IAdd",
SPIRV_Integer,
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
let summary = "Integer addition of Operand 1 and Operand 2.";

let description = [{
Expand Down Expand Up @@ -322,9 +361,10 @@ def SPIRV_IAddCarryOp : SPIRV_ArithmeticExtendedBinaryOp<"IAddCarry",

// -----

def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
def SPIRV_IMulOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"IMul",
SPIRV_Integer,
SPIRV_Integer,
[Commutative, UsableInSpecConstantOp]> {
let summary = "Integer multiplication of Operand 1 and Operand 2.";

let description = [{
Expand Down Expand Up @@ -354,9 +394,10 @@ def SPIRV_IMulOp : SPIRV_ArithmeticBinaryOp<"IMul",

// -----

def SPIRV_ISubOp : SPIRV_ArithmeticBinaryOp<"ISub",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
def SPIRV_ISubOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"ISub",
SPIRV_Integer,
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Integer subtraction of Operand 2 from Operand 1.";

let description = [{
Expand Down Expand Up @@ -445,12 +486,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];

let arguments = (ins
SPIRV_VectorOf<SPIRV_Float>:$vector1,
SPIRV_VectorOf<SPIRV_Float>:$vector2
SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector1,
SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector2
);

let results = (outs
SPIRV_Float:$result
SPIRV_FloatOrBFloat16:$result
);

let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
Expand All @@ -460,9 +501,10 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",

// -----

def SPIRV_SDivOp : SPIRV_ArithmeticBinaryOp<"SDiv",
SPIRV_Integer,
[UsableInSpecConstantOp]> {
def SPIRV_SDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"SDiv",
SPIRV_Integer,
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Signed-integer division of Operand 1 divided by Operand 2.";

let description = [{
Expand Down Expand Up @@ -560,6 +602,7 @@ def SPIRV_SMulExtendedOp : SPIRV_ArithmeticExtendedBinaryOp<"SMulExtended",
// -----

def SPIRV_SNegateOp : SPIRV_ArithmeticUnaryOp<"SNegate",
SPIRV_Integer,
SPIRV_Integer,
[UsableInSpecConstantOp]> {
let summary = "Signed-integer subtract of Operand from zero.";
Expand Down Expand Up @@ -622,9 +665,10 @@ def SPIRV_SRemOp : SPIRV_ArithmeticBinaryOp<"SRem",

// -----

def SPIRV_UDivOp : SPIRV_ArithmeticBinaryOp<"UDiv",
SPIRV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {
def SPIRV_UDivOp : SPIRV_ArithmeticWithCoopMatrixBinaryOp<"UDiv",
SPIRV_Integer,
SPIRV_Integer,
[UnsignedOp, UsableInSpecConstantOp]> {
let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";

let description = [{
Expand Down
52 changes: 45 additions & 7 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;

def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
Expand Down Expand Up @@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
SPV_KHR_cooperative_matrix,
SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
Expand Down Expand Up @@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}

def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
Expand Down Expand Up @@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
]>;

def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
Expand Down Expand Up @@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;

def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
list<Availability> availability = [
Capability<[SPIRV_C_BFloat16TypeKHR]>
];
}
def SPIRV_FPEncodingAttr :
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
SPIRV_FPE_BFloat16KHR
]>;

def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
Expand Down Expand Up @@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
[SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
Copy link
Contributor Author

@fairywreath fairywreath May 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can argue that bf16 should be part of SPIRV_Float. The problem here is that bf16 usage in SPIRV is very limited while SPIRV_Float(i.e regular floats) is used widely in the codebase for other ops(eg. texture sampling and regular arithmetic insts). I chose to leave SPIRV_Float to minimize the amount of changes(and to not introduce something like SPIRV_ArithmeticFloat). Please let me know if you think there is a cleaner solution to this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this, it might be better to leave them alone. what about something like this:

def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : AnyTypeOf<[F16, F32, F64]>;
def SPIRV_Float16or32 : AnyTypeOf<[F16, F32]>;
// Use this type for all kinds of floats.
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float]>;
.....
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
......

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
.........

def SPIRV_Type : AnyTypeOf<[
    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;

// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
Expand Down Expand Up @@ -4194,9 +4224,9 @@ def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
]>;

def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
Expand All @@ -4215,16 +4245,24 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> :
class SPIRV_VectorOf<Type type> :
VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;

class SPIRV_CoopMatrixOf<Type type> :
SPIRV_CoopMatrixOfType<[type]>;

class SPIRV_ScalarOrVectorOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;

class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>,
SPIRV_CoopMatrixOfType<[type]>]>;
SPIRV_CoopMatrixOf<type>]>;

class SPIRV_ScalarOrVectorOfOrCoopMatrixOf<Type scalarVectorType,
Type coopMatrixType> :
AnyTypeOf<[scalarVectorType, SPIRV_VectorOf<scalarVectorType>,
SPIRV_CoopMatrixOf<coopMatrixType>]>;

class SPIRV_MatrixOrCoopMatrixOf<Type type> :
AnyTypeOf<[SPIRV_AnyMatrix,
SPIRV_CoopMatrixOfType<[type]>]>;
SPIRV_CoopMatrixOf<type>]>;

class SPIRV_MatrixOf<Type type> :
SPIRV_MatrixOfType<[type]>;
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {

// -----

def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
Expand All @@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float

// -----

def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
Expand All @@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----

def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
Expand All @@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----

def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
Expand All @@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----

def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
SPIRV_Float,
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_FloatOrBFloat16,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,

// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
if (type.isBF16()) {
parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
return Type();
}
// TODO: All float types are allowed for now, but this should be fixed.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will address this in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate what needs to be fixed here?

Copy link
Contributor Author

@fairywreath fairywreath May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behavior does not error out on bitwidths that are invalid for SPIRV (eg. F80, F128) and non-standard formats (eg. E3M2). Do you think it's better to address this here or in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion it's okay to address it later. In fact I think it's preferable. Currently the code doesn't do any checks anyway, other than checking for bf16, so adding a proper check would be out of scope of this PR.

} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
}

bool ScalarType::isValid(FloatType type) {
return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}

bool ScalarType::isValid(IntegerType type) {
Expand Down
Loading