Skip to content

Commit 86340d5

Browse files
MrSidims authored and sys-ce-bb committed
Support for SPV_INTEL_shader_atomic_bfloat16 extension (#3343)
Spec is available here: #20009
Author: "Ratajewski, Andrzej" <[email protected]>
Signed-off-by: Sidorov, Dmitry <[email protected]>
Original commit: KhronosGroup/SPIRV-LLVM-Translator@8e8c02c4b803062
1 parent ac31b7a commit 86340d5

File tree

10 files changed

+209
-3
lines changed

10 files changed

+209
-3
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,4 @@ EXT(SPV_INTEL_bfloat16_arithmetic)
8080
EXT(SPV_INTEL_ternary_bitwise_function)
8181
EXT(SPV_INTEL_int4)
8282
EXT(SPV_INTEL_function_variants)
83+
EXT(SPV_INTEL_shader_atomic_bfloat16)

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3019,40 +3019,48 @@ class SPIRVAtomicFAddEXTInst : public SPIRVAtomicInstBase {
30193019
public:
30203020
std::optional<ExtensionID> getRequiredExtension() const override {
30213021
assert(hasType());
3022+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3023+
return ExtensionID::SPV_INTEL_shader_atomic_bfloat16;
30223024
if (getType()->isTypeFloat(16))
30233025
return ExtensionID::SPV_EXT_shader_atomic_float16_add;
30243026
return ExtensionID::SPV_EXT_shader_atomic_float_add;
30253027
}
30263028

30273029
SPIRVCapVec getRequiredCapability() const override {
30283030
assert(hasType());
3031+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3032+
return {internal::CapabilityAtomicBFloat16AddINTEL};
30293033
if (getType()->isTypeFloat(16))
30303034
return {CapabilityAtomicFloat16AddEXT};
30313035
if (getType()->isTypeFloat(32))
30323036
return {CapabilityAtomicFloat32AddEXT};
30333037
if (getType()->isTypeFloat(64))
30343038
return {CapabilityAtomicFloat64AddEXT};
30353039
llvm_unreachable(
3036-
"AtomicFAddEXT can only be generated for f16, f32, f64 types");
3040+
"AtomicFAddEXT can only be generated for bf16, f16, f32, f64 types");
30373041
}
30383042
};
30393043

30403044
class SPIRVAtomicFMinMaxEXTBase : public SPIRVAtomicInstBase {
30413045
public:
30423046
std::optional<ExtensionID> getRequiredExtension() const override {
3047+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3048+
return ExtensionID::SPV_INTEL_shader_atomic_bfloat16;
30433049
return ExtensionID::SPV_EXT_shader_atomic_float_min_max;
30443050
}
30453051

30463052
SPIRVCapVec getRequiredCapability() const override {
30473053
assert(hasType());
3054+
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
3055+
return {internal::CapabilityAtomicBFloat16MinMaxINTEL};
30483056
if (getType()->isTypeFloat(16))
30493057
return {CapabilityAtomicFloat16MinMaxEXT};
30503058
if (getType()->isTypeFloat(32))
30513059
return {CapabilityAtomicFloat32MinMaxEXT};
30523060
if (getType()->isTypeFloat(64))
30533061
return {CapabilityAtomicFloat64MinMaxEXT};
3054-
llvm_unreachable(
3055-
"AtomicF(Min|Max)EXT can only be generated for f16, f32, f64 types");
3062+
llvm_unreachable("AtomicF(Min|Max)EXT can only be generated for bf16, f16, "
3063+
"f32, f64 types");
30563064
}
30573065
};
30583066

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,9 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
638638
add(CapabilityLongCompositesINTEL, "LongCompositesINTEL");
639639
add(CapabilityOptNoneINTEL, "OptNoneINTEL");
640640
add(CapabilityAtomicFloat16AddEXT, "AtomicFloat16AddEXT");
641+
add(internal::CapabilityAtomicBFloat16AddINTEL, "AtomicBFloat16AddINTEL");
642+
add(internal::CapabilityAtomicBFloat16MinMaxINTEL,
643+
"AtomicBFloat16MinMaxINTEL");
641644
add(CapabilityDebugInfoModuleINTEL, "DebugInfoModuleINTEL");
642645
add(CapabilityBFloat16ConversionINTEL, "Bfloat16ConversionINTEL");
643646
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");

llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ enum InternalCapability {
109109
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
110110
ICapabilityBFloat16ArithmeticINTEL = 6226,
111111
ICapabilityCooperativeMatrixOffsetInstructionsINTEL = 6238,
112+
ICapabilityAtomicBFloat16AddINTEL = 6255,
113+
ICapabilityAtomicBFloat16MinMaxINTEL = 6256,
112114
ICapabilityCooperativeMatrixPrefetchINTEL = 6411,
113115
ICapabilityMaskedGatherScatterINTEL = 6427,
114116
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
@@ -206,6 +208,9 @@ _SPIRV_OP(Capability, BindlessImagesINTEL)
206208
_SPIRV_OP(Op, ConvertHandleToImageINTEL)
207209
_SPIRV_OP(Op, ConvertHandleToSamplerINTEL)
208210
_SPIRV_OP(Op, ConvertHandleToSampledImageINTEL)
211+
212+
_SPIRV_OP(Capability, AtomicBFloat16AddINTEL)
213+
_SPIRV_OP(Capability, AtomicBFloat16MinMaxINTEL)
209214
#undef _SPIRV_OP
210215

211216
constexpr SourceLanguage SourceLanguagePython =
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV
8+
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spir64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability AtomicBFloat16AddINTEL
13+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
15+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
16+
17+
; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0
18+
19+
; Function Attrs: convergent norecurse nounwind
20+
define dso_local spir_func bfloat @test_AtomicFAddEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
21+
entry:
22+
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
23+
; CHECK-SPIRV: AtomicFAddEXT [[BFLOAT]]
24+
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
25+
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
26+
ret bfloat %ret
27+
}
28+
29+
; Function Attrs: convergent
30+
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV
8+
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spir64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability AtomicBFloat16MinMaxINTEL
13+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
15+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
16+
17+
; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0
18+
19+
; Function Attrs: convergent norecurse nounwind
20+
define dso_local spir_func bfloat @test_AtomicFMaxEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
21+
entry:
22+
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
23+
; CHECK-SPIRV: AtomicFMaxEXT [[BFLOAT]]
24+
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
25+
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
26+
ret bfloat %ret
27+
}
28+
29+
; Function Attrs: convergent
30+
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV
8+
9+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
10+
target triple = "spir64-unknown-unknown"
11+
12+
; CHECK-SPIRV-DAG: Capability AtomicBFloat16MinMaxINTEL
13+
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
14+
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
15+
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"
16+
17+
; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0
18+
19+
; Function Attrs: convergent norecurse nounwind
20+
define dso_local spir_func bfloat @test_AtomicFMinEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
21+
entry:
22+
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
23+
; CHECK-SPIRV: AtomicFMinEXT [[BFLOAT]]
24+
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
25+
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
26+
ret bfloat %ret
27+
}
28+
29+
; Function Attrs: convergent
30+
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 %t.bc -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s
4+
5+
; CHECK-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
6+
; CHECK-DAG: Extension "SPV_KHR_bfloat16"
7+
; CHECK-DAG: Capability AtomicBFloat16AddINTEL
8+
; CHECK-DAG: Capability BFloat16TypeKHR
9+
; CHECK: TypeInt [[Int:[0-9]+]] 32 0
10+
; CHECK-DAG: Constant [[Int]] [[Scope_CrossDevice:[0-9]+]] 0 {{$}}
11+
; CHECK-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
12+
; CHECK: TypeFloat [[BFloat:[0-9]+]] 16 0
13+
; CHECK: Variable {{[0-9]+}} [[BFloatPointer:[0-9]+]]
14+
; CHECK: Constant [[BFloat]] [[BFloatValue:[0-9]+]] 16936
15+
16+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
17+
target triple = "spir64"
18+
19+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
20+
21+
; Function Attrs: nounwind
22+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
23+
entry:
24+
%0 = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
25+
; CHECK: AtomicFAddEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Scope_CrossDevice]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
26+
27+
ret void
28+
}
29+
30+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
31+
32+
!llvm.module.flags = !{!0}
33+
34+
!0 = !{i32 1, !"wchar_size", i32 4}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 %t.bc -o %t.spv
3+
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s
4+
5+
; CHECK-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
6+
; CHECK-DAG: Extension "SPV_KHR_bfloat16"
7+
; CHECK-DAG: Capability AtomicBFloat16MinMaxINTEL
8+
; CHECK-DAG: Capability BFloat16TypeKHR
9+
; CHECK: TypeInt [[Int:[0-9]+]] 32 0
10+
; CHECK-DAG: Constant [[Int]] [[Scope_CrossDevice:[0-9]+]] 0 {{$}}
11+
; CHECK-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
12+
; CHECK: TypeFloat [[BFloat:[0-9]+]] 16 0
13+
; CHECK: Variable {{[0-9]+}} [[BFloatPointer:[0-9]+]]
14+
; CHECK: Constant [[BFloat]] [[BFloatValue:[0-9]+]] 16936
15+
16+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
17+
target triple = "spir64"
18+
19+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 4
20+
21+
; Function Attrs: nounwind
22+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
23+
entry:
24+
%0 = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
25+
; CHECK: AtomicFMinEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Scope_CrossDevice]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
26+
%1 = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
27+
; CHECK: AtomicFMaxEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Scope_CrossDevice]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
28+
29+
ret void
30+
}
31+
32+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
33+
34+
!llvm.module.flags = !{!0}
35+
36+
!0 = !{i32 1, !"wchar_size", i32 4}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: not llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16 %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-NO-BF
3+
; RUN: not llvm-spirv --spirv-ext=+SPV_KHR_bfloat16 %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ATOM
4+
5+
; CHECK-NO-BF: RequiresExtension: Feature requires the following SPIR-V extension:
6+
; CHECK-NO-BF-NEXT: SPV_KHR_bfloat16
7+
; CHECK-NO-BF-NEXT: NOTE: LLVM module contains bfloat type, translation of which requires this extension
8+
9+
; CHECK-NO-ATOM: RequiresExtension: Feature requires the following SPIR-V extension:
10+
; CHECK-NO-ATOM-NEXT: SPV_INTEL_shader_atomic_bfloat16
11+
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
13+
target triple = "spir64"
14+
15+
@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8
16+
17+
; Function Attrs: nounwind
18+
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
19+
entry:
20+
%0 = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
21+
22+
ret void
23+
}
24+
25+
attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
26+
27+
!llvm.module.flags = !{!0}
28+
29+
!0 = !{i32 1, !"wchar_size", i32 4}

0 commit comments

Comments
 (0)