Skip to content

Commit e9918c1

Browse files
committed
Add LLVM patch to add MLIR Intel spirv bf16<->f32 conversion ops.
1 parent f5af018 commit e9918c1

File tree

2 files changed

+347
-0
lines changed

2 files changed

+347
-0
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ jobs:
146146
git clone https://github.com/llvm/llvm-project --branch main --single-branch || exit 1
147147
cd llvm-project || exit 1
148148
git checkout ${LLVM_SHA} || exit 1
149+
/home/runner/work/mlir-extensions/mlir-extensions
150+
if [ -d "/home/runner/work/mlir-extensions/mlir-extensions/build_tools/patches" ]; then git apply /home/runner/work/mlir-extensions/mlir-extensions/build_tools/patches/*.patch; fi
149151
mkdir _build || exit 1
150152
cd _build || exit 1
151153
cmake ../llvm \
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
From 036fb84a2ff8d499c07e270436599d7c35f8de6a Mon Sep 17 00:00:00 2001
2+
From: Md Abdullah Shahneous Bari <[email protected]>
3+
Date: Fri, 27 Jan 2023 22:50:59 +0000
4+
Subject: [PATCH] Add OpExtension "SPV_INTEL_bfloat16_conversion"
5+
6+
Add Intel-specific "SPV_INTEL_bfloat16_conversion" extension and
7+
capability (Bfloat16ConversionINTEL), and
8+
two ops (OpConvertFToBF16INTEL, OpConvertBF16ToFINTEL)
9+
that are introduced by this extension.
10+
---
11+
.../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 19 ++-
12+
.../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 130 ++++++++++++++++++
13+
.../include/mlir/Dialect/SPIRV/IR/SPIRVOps.td | 1 +
14+
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 40 ++++++
15+
mlir/test/Target/SPIRV/intel-ext-ops.mlir | 45 ++++++
16+
5 files changed, 232 insertions(+), 3 deletions(-)
17+
create mode 100644 mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
18+
create mode 100644 mlir/test/Target/SPIRV/intel-ext-ops.mlir
19+
20+
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
21+
index 87aa084c8783..40875c0892b6 100644
22+
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
23+
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
24+
@@ -399,6 +399,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp
25+
def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>;
26+
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
27+
def SPV_INTEL_joint_matrix : I32EnumAttrCase<"SPV_INTEL_joint_matrix", 4030>;
28+
+def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
29+
30+
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
31+
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
32+
@@ -457,7 +458,7 @@ def SPIRV_ExtensionAttr :
33+
SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone,
34+
SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode,
35+
SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, SPV_INTEL_joint_matrix,
36+
- SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
37+
+ SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix,
38+
SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough,
39+
SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage,
40+
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
41+
@@ -1413,6 +1414,12 @@ def SPIRV_C_JointMatrixINTEL : I32EnumAttrCase<"JointMat
42+
];
43+
}
44+
45+
+def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
46+
+ list<Availability> availability = [
47+
+ Extension<[SPV_INTEL_bfloat16_conversion]>
48+
+ ];
49+
+}
50+
+
51+
def SPIRV_CapabilityAttr :
52+
SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
53+
SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16,
54+
@@ -1504,7 +1511,7 @@ def SPIRV_CapabilityAttr :
55+
SPIRV_C_UniformTexelBufferArrayNonUniformIndexing,
56+
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
57+
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
58+
- SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL
59+
+ SPIRV_C_ShaderStereoViewNV, SPIRV_C_JointMatrixINTEL, SPIRV_C_Bfloat16ConversionINTEL
60+
]>;
61+
62+
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
63+
@@ -4079,6 +4086,7 @@ def SPIRV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
64+
def SPIRV_Void : TypeAlias<NoneType, "void">;
65+
def SPIRV_Bool : TypeAlias<I1, "bool">;
66+
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
67+
+def SPIRV_Int16 : TypeAlias<I16, "Int16">;
68+
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
69+
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
70+
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
71+
@@ -4412,6 +4420,9 @@ def SPIRV_OC_OpJointMatrixStoreINTEL : I32EnumAttrCase<"OpJointMatrixStoreI
72+
def SPIRV_OC_OpJointMatrixMadINTEL : I32EnumAttrCase<"OpJointMatrixMadINTEL", 6122>;
73+
def SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL : I32EnumAttrCase<"OpJointMatrixWorkItemLengthINTEL", 6410>;
74+
75+
+def SPIRV_OC_OpConvertFToBF16INTEL : I32EnumAttrCase<"OpConvertFToBF16INTEL", 6116>;
76+
+def SPIRV_OC_OpConvertBF16ToFINTEL : I32EnumAttrCase<"OpConvertBF16ToFINTEL", 6117>;
77+
+
78+
def SPIRV_OpcodeAttr :
79+
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
80+
SPIRV_OC_OpNop, SPIRV_OC_OpUndef, SPIRV_OC_OpSourceContinued,
81+
@@ -4497,7 +4508,9 @@ def SPIRV_OpcodeAttr :
82+
83+
SPIRV_OC_OpTypeJointMatrixINTEL, SPIRV_OC_OpJointMatrixLoadINTEL,
84+
SPIRV_OC_OpJointMatrixStoreINTEL, SPIRV_OC_OpJointMatrixMadINTEL,
85+
- SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL
86+
+ SPIRV_OC_OpTypejointMatrixWorkItemLengthINTEL,
87+
+
88+
+ SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL
89+
]>;
90+
91+
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
92+
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
93+
new file mode 100644
94+
index 000000000000..a02f093aa50b
95+
--- /dev/null
96+
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td
97+
@@ -0,0 +1,130 @@
98+
+//===- SPIRVIntelExtOps.td - Intel SPIR-V extensions ---------------*- tablegen -*-===//
99+
+//
100+
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
101+
+// See https://llvm.org/LICENSE.txt for license information.
102+
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
103+
+//
104+
+//===----------------------------------------------------------------------===//
105+
+//
106+
+// This is the op definition spec of Intel-specific SPIR-V extensions.
106+
+// These extensions are not part of the Khronos specification but are publicly
107+
+// available at https://github.com/intel/llvm.
108+
+// Supported extensions:
109+
+// * SPV_INTEL_bfloat16_conversion
111+
+//===----------------------------------------------------------------------===//
112+
+
113+
+
114+
+#ifndef MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
115+
+#define MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
116+
+
117+
+// -----
118+
+
119+
+def SPIRV_INTELConvertFToBF16Op : SPIRV_IntelVendorOp<"ConvertFToBF16", []> {
120+
+ let summary = "See extension SPV_INTEL_bfloat16_conversion";
121+
+
122+
+ let description = [{
123+
+ Convert value numerically from 32-bit floating point to bfloat16,
124+
+ which is represented as a 16-bit unsigned integer.
125+
+
126+
+ Result Type must be a scalar or vector of integer type.
127+
+ The component width must be 16 bits. Bit pattern in the Result represents a bfloat16 value.
128+
+
129+
+ Float Value must be a scalar or vector of floating-point type.
130+
+ It must have the same number of components as Result Type. The component width must be 32 bits.
131+
+
132+
+ Results are computed per component.
133+
+
134+
+ ```
135+
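+ float32-scalar-vector-type ::= float32-type | `vector<` integer-literal `x` float32-type `>`
135+
+ integer16-scalar-vector-type ::= integer16-type | `vector<` integer-literal `x` integer16-type `>`
136+
+ ConvertFToBF16-op ::= ssa-id `=` `spirv.INTEL.ConvertFToBF16` ssa-use
137+
+ `:` float32-scalar-vector-type `to` integer16-scalar-vector-type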
139+
+ ```
140+
+
141+
+ #### Example:
142+
+
143+
+ ```mlir
144+
+ %2 = spirv.INTEL.ConvertFToBF16 %0 : f32 to i16
145+
+ %3 = spirv.INTEL.ConvertFToBF16 %1 : vector<4xf32> to vector<4xi16>
146+
+
147+
+ ```
148+
+ }];
149+
+
150+
+
151+
+ let availability = [
152+
+ MinVersion<SPIRV_V_1_0>,
153+
+ MaxVersion<SPIRV_V_1_6>,
154+
+ Extension<[SPV_INTEL_bfloat16_conversion]>,
155+
+ Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
156+
+ ];
157+
+
158+
+ let arguments = (ins
159+
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$operand
160+
+ );
161+
+
162+
+ let results = (outs
163+
+ SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$result
164+
+ );
165+
+ let assemblyFormat = [{
166+
+ $operand attr-dict `:` type($operand) `to` type($result)
167+
+ }];
168+
+
169+
+ let hasVerifier = 1;
170+
+}
171+
+
172+
+// -----
173+
+
174+
+def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> {
175+
+ let summary = "See extension SPV_INTEL_bfloat16_conversion";
176+
+
177+
+ let description = [{
178+
+ Interpret a 16-bit integer as bfloat16 and convert the value numerically to 32-bit floating point type.
179+
+
180+
+ Result Type must be a scalar or vector of floating-point. The component width must be 32 bits.
181+
+
182+
+ Bfloat16 Value must be a scalar or vector of integer type, which is interpreted as a bfloat16 type.
183+
+ The type must have the same number of components as the Result Type. The component width must be 16 bits.
184+
+
185+
+ Results are computed per component.
186+
+
187+
+ ```
188+
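+ integer16-scalar-vector-type ::= integer16-type | `vector<` integer-literal `x` integer16-type `>`
188+
+ float32-scalar-vector-type ::= float32-type | `vector<` integer-literal `x` float32-type `>`
189+
+ ConvertBF16ToF-op ::= ssa-id `=` `spirv.INTEL.ConvertBF16ToF` ssa-use
190+
+ `:` integer16-scalar-vector-type `to` float32-scalar-vector-type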
192+
+ ```
193+
+
194+
+ #### Example:
195+
+
196+
+ ```mlir
197+
+ %2 = spirv.INTEL.ConvertBF16ToF %0 : i16 to f32
198+
+ %3 = spirv.INTEL.ConvertBF16ToF %1 : vector<4xi16> to vector<4xf32>
199+
+
200+
+ ```
201+
+ }];
202+
+
203+
+ let availability = [
204+
+ MinVersion<SPIRV_V_1_0>,
205+
+ MaxVersion<SPIRV_V_1_6>,
206+
+ Extension<[SPV_INTEL_bfloat16_conversion]>,
207+
+ Capability<[SPIRV_C_Bfloat16ConversionINTEL]>
208+
+ ];
209+
+
210+
+ let arguments = (ins
211+
+ SPIRV_ScalarOrVectorOf<SPIRV_Int16>:$operand
212+
+ );
213+
+
214+
+ let results = (outs
215+
+ SPIRV_ScalarOrVectorOf<SPIRV_Float32>:$result
216+
+ );
217+
+
218+
+ let assemblyFormat = [{
219+
+ $operand attr-dict `:` type($operand) `to` type($result)
220+
+ }];
221+
+ let hasVerifier = 1;
222+
+}
223+
+
224+
+
225+
+// -----
226+
+
227+
+#endif // MLIR_DIALECT_SPIRV_IR_INTEL_EXT_OPS
228+
\ No newline at end of file
229+
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
230+
index 767e939f0447..13533d1d65b8 100644
231+
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
232+
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.td
233+
@@ -31,6 +31,7 @@ include "mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td"
234+
include "mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td"
235+
include "mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td"
236+
include "mlir/Dialect/SPIRV/IR/SPIRVJointMatrixOps.td"
237+
+include "mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td"
238+
include "mlir/Dialect/SPIRV/IR/SPIRVGLOps.td"
239+
include "mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td"
240+
include "mlir/Dialect/SPIRV/IR/SPIRVImageOps.td"
241+
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
242+
index 2a11c4c6a9dd..96e96ea5b38c 100644
243+
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
244+
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
245+
@@ -2221,6 +2221,46 @@ LogicalResult spirv::ConvertUToFOp::verify() {
246+
/*skipBitWidthCheck=*/true);
247+
}
248+
249+
+//===----------------------------------------------------------------------===//
250+
+// spirv.INTELConvertBF16ToFOp
251+
+//===----------------------------------------------------------------------===//
252+
+
253+
+LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
254+
+  auto operandVecType = getOperand().getType().dyn_cast<VectorType>();
255+
+  auto resultVecType = getResult().getType().dyn_cast<VectorType>();
256+
+  // ODS only constrains the element types, so the shapes are matched here: a
257+
+  // scalar must convert to a scalar, a vector to a vector of equal length.
258+
+  const bool sameShape =
259+
+      (!operandVecType && !resultVecType) ||
260+
+      (operandVecType && resultVecType &&
261+
+       operandVecType.getNumElements() == resultVecType.getNumElements());
262+
+  if (!sameShape) {
263+
+    return emitOpError(
264+
+        "operand and result must have same number of elements");
265+
+  }
266+
+  return success();
267+
+}
268+
+
269+
+//===----------------------------------------------------------------------===//
270+
+// spirv.INTELConvertFToBF16Op
271+
+//===----------------------------------------------------------------------===//
272+
+
273+
+LogicalResult spirv::INTELConvertFToBF16Op::verify() {
274+
+  auto operandVecType = getOperand().getType().dyn_cast<VectorType>();
275+
+  auto resultVecType = getResult().getType().dyn_cast<VectorType>();
276+
+  // ODS only constrains the element types, so the shapes are matched here: a
277+
+  // scalar must convert to a scalar, a vector to a vector of equal length.
278+
+  const bool sameShape =
279+
+      (!operandVecType && !resultVecType) ||
280+
+      (operandVecType && resultVecType &&
281+
+       operandVecType.getNumElements() == resultVecType.getNumElements());
282+
+  if (!sameShape) {
283+
+    return emitOpError(
284+
+        "operand and result must have same number of elements");
285+
+  }
286+
+  return success();
287+
+}
288+
+
289+
//===----------------------------------------------------------------------===//
290+
// spirv.EntryPoint
291+
//===----------------------------------------------------------------------===//
292+
diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
293+
new file mode 100644
294+
index 000000000000..8e19f95364e6
295+
--- /dev/null
296+
+++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir
297+
@@ -0,0 +1,45 @@
298+
+// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
299+
+
300+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Bfloat16ConversionINTEL], [SPV_INTEL_bfloat16_conversion]> {
301+
+ // CHECK-LABEL: @f32_to_bf16
302+
+ spirv.func @f32_to_bf16(%arg0 : f32) "None" {
303+
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f32 to i16
304+
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f32 to i16
305+
+ spirv.Return
306+
+ }
307+
+
308+
+ // CHECK-LABEL: @f32_to_bf16_vec
309+
+ spirv.func @f32_to_bf16_vec(%arg0 : vector<2xf32>) "None" {
310+
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : vector<2xf32> to vector<2xi16>
311+
+ %0 = spirv.INTEL.ConvertFToBF16 %arg0 : vector<2xf32> to vector<2xi16>
312+
+ spirv.Return
313+
+ }
314+
+
315+
+ // CHECK-LABEL: @bf16_to_f32
316+
+ spirv.func @bf16_to_f32(%arg0 : i16) "None" {
317+
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : i16 to f32
318+
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : i16 to f32
319+
+ spirv.Return
320+
+ }
321+
+
322+
+ // CHECK-LABEL: @bf16_to_f32_vec
323+
+ spirv.func @bf16_to_f32_vec(%arg0 : vector<2xi16>) "None" {
324+
+ // CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
325+
+ %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<2xf32>
326+
+ spirv.Return
327+
+ }
328+
+
329+
+  // TODO: move these negative cases to a -verify-diagnostics test; they fail verification and cannot round-trip.
330+
+  // COM: CHECK-LABEL: @f32_to_bf16_unsupported
331+
+  // spirv.func @f32_to_bf16_unsupported(%arg0 : f64) "None" {
332+
+  //   COM: CHECK: {{%.*}} = spirv.INTEL.ConvertFToBF16 {{%.*}} : f64 to i16
333+
+  //   %0 = spirv.INTEL.ConvertFToBF16 %arg0 : f64 to i16
334+
+  //   spirv.Return
335+
+  // }
336+
+  // COM: CHECK-LABEL: @bf16_to_f32_vec_unsupported
337+
+  // spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" {
338+
+  //   COM: CHECK: {{%.*}} = spirv.INTEL.ConvertBF16ToF {{%.*}} : vector<2xi16> to vector<2xf32>
339+
+  //   %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32>
340+
+  //   spirv.Return
341+
+  // }
342+
+}
343+
\ No newline at end of file
344+
--
345+
2.25.1

0 commit comments

Comments
 (0)