Skip to content

Commit 408bcb1

Browse files
authored
Merge pull request #827 from mshahneo/bf16_native_support_for_upsream
[SPIR-V][test][llvm][patch] Add native bf16 support in SPIR-V dialect…
2 parents 2979612 + ed45c49 commit 408bcb1

11 files changed

+1984
-4
lines changed

build_tools/patches/0014-SPIR-V-Enable-native-bf16-support-in-SPIR-V-dialect.patch

Lines changed: 440 additions & 0 deletions
Large diffs are not rendered by default.

lib/Transforms/SetSPIRVCapabilities.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ struct SetSPIRVCapabilitiesPass
5454
spirv::Capability caps_opencl[] = {
5555
// clang-format off
5656
spirv::Capability::Addresses,
57+
spirv::Capability::Bfloat16ConversionINTEL,
58+
spirv::Capability::BFloat16TypeKHR,
5759
spirv::Capability::Float16Buffer,
5860
spirv::Capability::Int64,
5961
spirv::Capability::Int16,
6062
spirv::Capability::Int8,
61-
spirv::Capability::Bfloat16ConversionINTEL,
6263
spirv::Capability::Kernel,
6364
spirv::Capability::Linkage,
6465
spirv::Capability::Vector16,
@@ -77,10 +78,14 @@ struct SetSPIRVCapabilitiesPass
7778
// clang-format on
7879
};
7980
spirv::Extension exts_opencl[] = {
80-
spirv::Extension::SPV_INTEL_bfloat16_conversion,
81+
// clang-format off
8182
spirv::Extension::SPV_EXT_shader_atomic_float_add,
83+
spirv::Extension::SPV_KHR_bfloat16,
8284
spirv::Extension::SPV_KHR_expect_assume,
83-
spirv::Extension::SPV_INTEL_vector_compute};
85+
spirv::Extension::SPV_INTEL_bfloat16_conversion,
86+
spirv::Extension::SPV_INTEL_vector_compute
87+
// clang-format on
88+
};
8489
spirv::Extension exts_vulkan[] = {
8590
spirv::Extension::SPV_KHR_storage_buffer_storage_class};
8691
if (m_clientAPI == "opencl") {
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/gpu-to-llvm.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
10+
module @eltwise_add attributes {gpu.container_module} {
11+
memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01>
12+
func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
13+
%c20 = arith.constant 20 : index
14+
%c10 = arith.constant 10 : index
15+
%c1 = arith.constant 1 : index
16+
%memref = gpu.alloc host_shared () : memref<10x20xbf16>
17+
memref.copy %arg1, %memref : memref<10x20xbf16> to memref<10x20xbf16>
18+
%memref_0 = gpu.alloc host_shared () : memref<10x20xbf16>
19+
memref.copy %arg0, %memref_0 : memref<10x20xbf16> to memref<10x20xbf16>
20+
%memref_1 = gpu.alloc host_shared () : memref<10x20xbf16>
21+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%memref_0 : memref<10x20xbf16>, %memref : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>)
22+
%alloc = memref.alloc() : memref<10x20xbf16>
23+
memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16>
24+
gpu.dealloc %memref_1 : memref<10x20xbf16>
25+
gpu.dealloc %memref_0 : memref<10x20xbf16>
26+
gpu.dealloc %memref : memref<10x20xbf16>
27+
return %alloc : memref<10x20xbf16>
28+
}
29+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Bfloat16ConversionINTEL, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, VectorAnyINTEL, BFloat16TypeKHR], [SPV_INTEL_bfloat16_conversion, SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_KHR_bfloat16]>, api=OpenCL, #spirv.resource_limits<>>} {
30+
gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
31+
%block_id_x = gpu.block_id x
32+
%block_id_y = gpu.block_id y
33+
%0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
34+
%1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
35+
%2 = arith.addf %0, %1 : bf16
36+
memref.store %2, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
37+
gpu.return
38+
}
39+
}
40+
func.func @main() {
41+
%0 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
42+
%1 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
43+
%2 = call @test(%0, %1) : (memref<10x20xbf16>, memref<10x20xbf16>) -> memref<10x20xbf16>
44+
%cast = memref.cast %2 : memref<10x20xbf16> to memref<*xbf16>
45+
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
46+
// CHECK-COUNT-200: 1
47+
call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
48+
return
49+
}
50+
func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
51+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// gpu dialect with intel intrinsic functions (func dialect) to
2+
// llvm dialect (for host code) and
3+
// spirv dialect (for device code) lowering pipeline.
4+
// Ready for imex runner starting from GPU dialect.
5+
builtin.module(
6+
imex-vector-linearize
7+
reconcile-unrealized-casts
8+
imex-convert-gpu-to-spirv
9+
spirv.module(spirv-lower-abi-attrs
10+
spirv-update-vce)
11+
func.func(llvm-request-c-wrappers)
12+
serialize-spirv
13+
convert-vector-to-scf
14+
convert-gpu-to-gpux
15+
convert-scf-to-cf
16+
convert-cf-to-llvm
17+
convert-vector-to-llvm
18+
convert-index-to-llvm
19+
convert-arith-to-llvm
20+
convert-func-to-llvm
21+
convert-math-to-llvm
22+
convert-gpux-to-llvm
23+
convert-index-to-llvm
24+
expand-strided-metadata
25+
lower-affine
26+
finalize-memref-to-llvm
27+
reconcile-unrealized-casts)
28+
// End

0 commit comments

Comments
 (0)