-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][spirv] Add bfloat16 support #141458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
45349e6
d9815d4
80565f0
5d2984b
3a4b936
40ae919
47ba696
49e8e86
3ed46d2
4248d7c
f2ba086
c0fee39
e01695b
ef930f4
2f8ce7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup | |
| def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>; | ||
| def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>; | ||
| def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>; | ||
| def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>; | ||
|
|
||
| def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>; | ||
| def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>; | ||
|
|
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr : | |
| SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask, | ||
| SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate, | ||
| SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation, | ||
| SPV_KHR_cooperative_matrix, | ||
| SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16, | ||
| SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing, | ||
| SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density, | ||
| SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer, | ||
|
|
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade | |
| Extension<[SPV_NV_stereo_view_rendering]> | ||
| ]; | ||
| } | ||
| def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> { | ||
| list<Availability> availability = [ | ||
| Extension<[SPV_KHR_bfloat16]> | ||
| ]; | ||
| } | ||
| def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> { | ||
| list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR]; | ||
| list<Availability> availability = [ | ||
| Extension<[SPV_KHR_bfloat16]> | ||
| ]; | ||
| } | ||
| def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> { | ||
| list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR]; | ||
| list<Availability> availability = [ | ||
| Extension<[SPV_KHR_bfloat16]> | ||
| ]; | ||
| } | ||
|
|
||
| def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> { | ||
| list<Availability> availability = [ | ||
|
|
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr : | |
| SPIRV_C_StorageTexelBufferArrayNonUniformIndexing, | ||
| SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, | ||
| SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL, | ||
| SPIRV_C_CacheControlsINTEL | ||
| SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR, | ||
| SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR, | ||
| ]>; | ||
|
|
||
| def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; | ||
|
|
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr : | |
| SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT | ||
| ]>; | ||
|
|
||
| def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> { | ||
| list<Availability> availability = [ | ||
| Capability<[SPIRV_C_BFloat16TypeKHR]> | ||
| ]; | ||
| } | ||
| def SPIRV_FPEncodingAttr : | ||
| SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [ | ||
| SPIRV_FPE_BFloat16KHR | ||
| ]>; | ||
|
|
||
| def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">; | ||
| def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>; | ||
| def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>; | ||
|
|
@@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">; | |
| def SPIRV_Float32 : TypeAlias<F32, "Float32">; | ||
| def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; | ||
| def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; | ||
| def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>; | ||
| def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], | ||
| [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>; | ||
| [SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>; | ||
|
||
| // Component type check is done in the type parser for the following SPIR-V | ||
| // dialect-specific types so we use "Any" here. | ||
| def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType, | ||
|
|
@@ -4194,9 +4224,9 @@ def SPIRV_Composite : | |
| AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, | ||
| SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>; | ||
| def SPIRV_Type : AnyTypeOf<[ | ||
| SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector, | ||
| SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector, | ||
| SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct, | ||
| SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage | ||
| SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage, | ||
fairywreath marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ]>; | ||
|
|
||
| def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>; | ||
|
|
@@ -4215,16 +4245,24 @@ class SPIRV_MatrixOfType<list<Type> allowedTypes> : | |
| class SPIRV_VectorOf<Type type> : | ||
| VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>; | ||
|
|
||
| class SPIRV_CoopMatrixOf<Type type> : | ||
fairywreath marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| SPIRV_CoopMatrixOfType<[type]>; | ||
|
|
||
| class SPIRV_ScalarOrVectorOf<Type type> : | ||
| AnyTypeOf<[type, SPIRV_VectorOf<type>]>; | ||
|
|
||
| class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> : | ||
| AnyTypeOf<[type, SPIRV_VectorOf<type>, | ||
| SPIRV_CoopMatrixOfType<[type]>]>; | ||
| SPIRV_CoopMatrixOf<type>]>; | ||
|
|
||
| class SPIRV_ScalarOrVectorOfOrCoopMatrixOf<Type scalarVectorType, | ||
| Type coopMatrixType> : | ||
| AnyTypeOf<[scalarVectorType, SPIRV_VectorOf<scalarVectorType>, | ||
| SPIRV_CoopMatrixOf<coopMatrixType>]>; | ||
|
|
||
| class SPIRV_MatrixOrCoopMatrixOf<Type type> : | ||
| AnyTypeOf<[SPIRV_AnyMatrix, | ||
| SPIRV_CoopMatrixOfType<[type]>]>; | ||
| SPIRV_CoopMatrixOf<type>]>; | ||
|
|
||
| class SPIRV_MatrixOf<Type type> : | ||
| SPIRV_MatrixOfType<[type]>; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, | |
|
|
||
| // Check other allowed types | ||
| if (auto t = llvm::dyn_cast<FloatType>(type)) { | ||
| if (type.isBF16()) { | ||
| parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); | ||
| return Type(); | ||
| } | ||
| // TODO: All float types are allowed for now, but this should be fixed. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will address this in a separate PR.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please elaborate what needs to be fixed here?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current behavior does not error out on bitwidths that are invalid for SPIRV (eg. F80, F128) and non-standard formats (eg. E3M2). Do you think it's better to address this here or in a separate PR?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my opinion it's okay to address it later. In fact I think it's preferable. Currently the code doesn't do any checks anyway, other than checking for bf16, so adding a proper check would be out of scope of this PR. |
||
| } else if (auto t = llvm::dyn_cast<IntegerType>(type)) { | ||
| if (!ScalarType::isValid(t)) { | ||
| parser.emitError(typeLoc, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.