@@ -21,63 +21,63 @@ index 22d5afcd7738..de9e11493793 100644
21
21
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
22
22
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
23
23
@@ -82,7 +82,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
24
-
24
+
25
25
// -----
26
-
26
+
27
27
- def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
28
28
+ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_AnyFloat, [Commutative]> {
29
29
let summary = "Floating-point addition of Operand 1 and Operand 2.";
30
-
30
+
31
31
let description = [{
32
32
@@ -104,7 +104,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
33
-
33
+
34
34
// -----
35
-
35
+
36
36
- def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
37
37
+ def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_AnyFloat, []> {
38
38
let summary = "Floating-point division of Operand 1 divided by Operand 2.";
39
-
39
+
40
40
let description = [{
41
41
@@ -154,7 +154,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
42
-
42
+
43
43
// -----
44
-
44
+
45
45
- def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
46
46
+ def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_AnyFloat, [Commutative]> {
47
47
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
48
-
48
+
49
49
let description = [{
50
50
@@ -229,7 +229,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
51
-
51
+
52
52
// -----
53
-
53
+
54
54
- def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
55
55
+ def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_AnyFloat, []> {
56
56
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
57
-
57
+
58
58
let description = [{
59
59
@@ -450,7 +450,7 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
60
60
);
61
-
61
+
62
62
let results = (outs
63
63
- SPIRV_Float:$result
64
64
+ SPIRV_AnyFloat:$result
65
65
);
66
-
66
+
67
67
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
68
68
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
69
- index 04952dd1dc61..6c9c348490ab 100644
69
+ index ddaeb13ef253..336bdcfb7a48 100644
70
70
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
71
71
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
72
72
@@ -343,6 +343,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
73
73
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
74
74
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
75
75
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
76
76
+ def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
77
-
77
+
78
78
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
79
79
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
80
- @@ -435 ,7 +436 ,7 @@ def SPIRV_ExtensionAttr :
80
+ @@ -434 ,7 +435 ,7 @@ def SPIRV_ExtensionAttr :
81
81
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
82
82
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
83
83
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
@@ -86,7 +86,7 @@ index 04952dd1dc61..6c9c348490ab 100644
86
86
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
87
87
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
88
88
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
89
- @@ -1193 ,6 +1194,22 @@ def SPIRV_C_ShaderClockKHR : I32EnumAttrCase<"Shade
89
+ @@ -1192 ,6 +1193,24 @@ def SPIRV_C_ShaderClockKHR : I32EnumAttrCase<"Shade
90
90
Extension<[SPV_KHR_shader_clock]>
91
91
];
92
92
}
@@ -95,11 +95,13 @@ index 04952dd1dc61..6c9c348490ab 100644
95
95
+ Extension<[SPV_KHR_bfloat16]>
96
96
+ ];
97
97
+ }
98
+ +
98
99
+ def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
99
100
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
100
101
+ list<Availability> availability = [
101
102
+ Extension<[SPV_KHR_bfloat16]> ];
102
103
+ }
104
+ +
103
105
+ def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
104
106
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
105
107
+ list<Availability> availability = [
@@ -109,15 +111,15 @@ index 04952dd1dc61..6c9c348490ab 100644
109
111
def SPIRV_C_FragmentFullyCoveredEXT : I32EnumAttrCase<"FragmentFullyCoveredEXT", 5265> {
110
112
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
111
113
list<Availability> availability = [
112
- @@ -1491 ,6 +1508 ,7 @@ def SPIRV_CapabilityAttr :
114
+ @@ -1484 ,6 +1503 ,7 @@ def SPIRV_CapabilityAttr :
113
115
SPIRV_C_RayQueryKHR, SPIRV_C_RayTracingKHR, SPIRV_C_Float16ImageAMD,
114
116
SPIRV_C_ImageGatherBiasLodAMD, SPIRV_C_FragmentMaskAMD, SPIRV_C_StencilExportEXT,
115
117
SPIRV_C_ImageReadWriteLodAMD, SPIRV_C_Int64ImageEXT, SPIRV_C_ShaderClockKHR,
116
118
+ SPIRV_C_BFloat16TypeKHR, SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
117
119
SPIRV_C_FragmentFullyCoveredEXT, SPIRV_C_MeshShadingNV, SPIRV_C_FragmentDensityEXT,
118
120
SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray,
119
121
SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV,
120
- @@ -4148 ,16 +4166 ,21 @@ def SPIRV_Bool : TypeAlias<I1, "bool">;
122
+ @@ -4139 ,16 +4159 ,21 @@ def SPIRV_Bool : TypeAlias<I1, "bool">;
121
123
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
122
124
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
123
125
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
@@ -142,36 +144,34 @@ index 04952dd1dc61..6c9c348490ab 100644
142
144
// Component type check is done in the type parser for the following SPIR-V
143
145
// dialect-specific types so we use "Any" here.
144
146
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
145
- @@ -4180 ,14 +4203 ,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
147
+ @@ -4169 ,14 +4194 ,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
146
148
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
147
149
"any SPIR-V sampled image type">;
148
-
150
+
149
151
- def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
150
152
+ def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
151
153
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
152
154
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
153
155
def SPIRV_Composite :
154
156
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
155
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
157
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
156
158
def SPIRV_Type : AnyTypeOf<[
157
159
- SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
158
160
+ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector,
159
161
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
160
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
161
- SPIRV_AnySampledImage
162
- @@ -4764,6 +4787,12 @@ def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 1
163
- ];
164
- }
165
-
162
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
163
+ ]>;
164
+ @@ -4745,4 +4770,10 @@ def SPIRV_FPFastMathModeAttr :
165
+ SPIRV_FPFMM_AllowReassocINTEL
166
+ ]>;
167
+
166
168
+ def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0>;
167
169
+ def SPIRV_FP_Encoding :
168
170
+ SPIRV_I32EnumAttr<"FPEncoding", "Valid floating-point encoding", "fpEncoding", [
169
171
+ SPIRV_FPE_BFloat16KHR
170
172
+ ]>;
171
173
+
172
- def SPIRV_FPFastMathModeAttr :
173
- SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
174
- SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
174
+ #endif // MLIR_DIALECT_SPIRV_IR_BASE
175
175
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
176
176
index b5ca27d7d753..703920e42c60 100644
177
177
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
0 commit comments