Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/LLVMSPIRVExtensions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,4 @@ EXT(SPV_INTEL_bfloat16_arithmetic)
EXT(SPV_INTEL_ternary_bitwise_function)
EXT(SPV_INTEL_int4)
EXT(SPV_INTEL_function_variants)
EXT(SPV_INTEL_shader_atomic_bfloat16)
14 changes: 11 additions & 3 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2864,40 +2864,48 @@ class SPIRVAtomicFAddEXTInst : public SPIRVAtomicInstBase {
public:
/// Picks the SPIR-V extension this atomic floating-point add depends on,
/// keyed on the instruction's result type.
///
/// The bfloat16 encoding is tested ahead of the generic 16-bit float check;
/// NOTE(review): presumably because isTypeFloat(16) would also accept
/// bfloat16 - confirm against SPIRVType::isTypeFloat.
///
/// \returns the bfloat16 atomics extension for bfloat16, the float16 add
/// extension for any other 16-bit float, and the generic float add extension
/// for every remaining type.
std::optional<ExtensionID> getRequiredExtension() const override {
  assert(hasType());
  auto *ResultTy = getType();
  const bool IsBFloat16 = ResultTy->isTypeFloat(16, FPEncodingBFloat16KHR);
  if (IsBFloat16)
    return ExtensionID::SPV_INTEL_shader_atomic_bfloat16;
  return ResultTy->isTypeFloat(16)
             ? ExtensionID::SPV_EXT_shader_atomic_float16_add
             : ExtensionID::SPV_EXT_shader_atomic_float_add;
}

/// Returns the capability this atomic floating-point add requires for its
/// result type.
///
/// The bfloat16 encoding is matched first, then 16-, 32- and 64-bit floats in
/// that order. Any other result type is treated as a translator bug and hits
/// llvm_unreachable.
SPIRVCapVec getRequiredCapability() const override {
  assert(hasType());
  // bfloat16 is tested before the generic 16-bit case so it is mapped to the
  // Intel bfloat16 capability rather than the float16 one.
  if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
    return {internal::CapabilityAtomicBFloat16AddINTEL};
  if (getType()->isTypeFloat(16))
    return {CapabilityAtomicFloat16AddEXT};
  if (getType()->isTypeFloat(32))
    return {CapabilityAtomicFloat32AddEXT};
  if (getType()->isTypeFloat(64))
    return {CapabilityAtomicFloat64AddEXT};
  // A single, well-formed call: the diagnostic lists every supported type,
  // including bf16.
  llvm_unreachable(
      "AtomicFAddEXT can only be generated for bf16, f16, f32, f64 types");
}
};

class SPIRVAtomicFMinMaxEXTBase : public SPIRVAtomicInstBase {
public:
std::optional<ExtensionID> getRequiredExtension() const override {
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
return ExtensionID::SPV_INTEL_shader_atomic_bfloat16;
return ExtensionID::SPV_EXT_shader_atomic_float_min_max;
}

SPIRVCapVec getRequiredCapability() const override {
assert(hasType());
if (getType()->isTypeFloat(16, FPEncodingBFloat16KHR))
return {internal::CapabilityAtomicBFloat16MinMaxINTEL};
if (getType()->isTypeFloat(16))
return {CapabilityAtomicFloat16MinMaxEXT};
if (getType()->isTypeFloat(32))
return {CapabilityAtomicFloat32MinMaxEXT};
if (getType()->isTypeFloat(64))
return {CapabilityAtomicFloat64MinMaxEXT};
llvm_unreachable(
"AtomicF(Min|Max)EXT can only be generated for f16, f32, f64 types");
llvm_unreachable("AtomicF(Min|Max)EXT can only be generated for bf16, f16, "
"f32, f64 types");
}
};

Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,9 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityLongCompositesINTEL, "LongCompositesINTEL");
add(CapabilityOptNoneEXT, "OptNoneEXT");
add(CapabilityAtomicFloat16AddEXT, "AtomicFloat16AddEXT");
add(internal::CapabilityAtomicBFloat16AddINTEL, "AtomicBFloat16AddINTEL");
add(internal::CapabilityAtomicBFloat16MinMaxINTEL,
"AtomicBFloat16MinMaxINTEL");
add(CapabilityDebugInfoModuleINTEL, "DebugInfoModuleINTEL");
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");
add(CapabilityGlobalVariableFPGADecorationsINTEL,
Expand Down
5 changes: 5 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ enum InternalCapability {
ICapGlobalVariableDecorationsINTEL = 6146,
ICapabilityCooperativeMatrixCheckedInstructionsINTEL = 6192,
ICapabilityBFloat16ArithmeticINTEL = 6226,
ICapabilityAtomicBFloat16AddINTEL = 6255,
ICapabilityAtomicBFloat16MinMaxINTEL = 6256,
ICapabilityCooperativeMatrixPrefetchINTEL = 6411,
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32RoundingINTEL = 6425,
Expand Down Expand Up @@ -170,6 +172,9 @@ _SPIRV_OP(Capability, BindlessImagesINTEL)
_SPIRV_OP(Op, ConvertHandleToImageINTEL)
_SPIRV_OP(Op, ConvertHandleToSamplerINTEL)
_SPIRV_OP(Op, ConvertHandleToSampledImageINTEL)

_SPIRV_OP(Capability, AtomicBFloat16AddINTEL)
_SPIRV_OP(Capability, AtomicBFloat16MinMaxINTEL)
#undef _SPIRV_OP

constexpr SourceLanguage SourceLanguagePython =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; CHECK-SPIRV-DAG: Capability AtomicBFloat16AddINTEL
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"

; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0

; Function Attrs: convergent norecurse nounwind
; Casts the addrspace(4) pointer argument to addrspace(1) and calls the mangled
; __spirv_AtomicFAddEXT builtin on it with a bfloat operand of 1.0, returning
; the call's result. The directives inside the body expect an AtomicFAddEXT
; whose result type is the 16-bit float type captured as BFLOAT in the SPIR-V
; text, and the same mangled call in the LLVM IR produced by the second RUN
; pipeline.
; NOTE(review): i32 1 and i32 896 are presumably the Scope and Memory Semantics
; operands of the builtin - confirm against the SPIR-V specification.
define dso_local spir_func bfloat @test_AtomicFAddEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
entry:
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
; CHECK-SPIRV: AtomicFAddEXT [[BFLOAT]]
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
ret bfloat %ret
}

; Function Attrs: convergent
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFAddEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; CHECK-SPIRV-DAG: Capability AtomicBFloat16MinMaxINTEL
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"

; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0

; Function Attrs: convergent norecurse nounwind
; Casts the addrspace(4) pointer argument to addrspace(1) and calls the mangled
; __spirv_AtomicFMaxEXT builtin on it with a bfloat operand of 1.0, returning
; the call's result. The directives inside the body expect an AtomicFMaxEXT
; whose result type is the 16-bit float type captured as BFLOAT in the SPIR-V
; text, and the same mangled call in the LLVM IR produced by the second RUN
; pipeline.
; NOTE(review): i32 1 and i32 896 are presumably the Scope and Memory Semantics
; operands of the builtin - confirm against the SPIR-V specification.
define dso_local spir_func bfloat @test_AtomicFMaxEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
entry:
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
; CHECK-SPIRV: AtomicFMaxEXT [[BFLOAT]]
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
ret bfloat %ret
}

; Function Attrs: convergent
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMaxEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 -o %t.spv
; RUN: llvm-spirv -to-text %t.spv -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV

; RUN: llvm-spirv --spirv-target-env=SPV-IR -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefixes=CHECK-LLVM-SPV

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

; CHECK-SPIRV-DAG: Capability AtomicBFloat16MinMaxINTEL
; CHECK-SPIRV-DAG: Capability BFloat16TypeKHR
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
; CHECK-SPIRV-DAG: Extension "SPV_KHR_bfloat16"

; CHECK-SPIRV: TypeFloat [[BFLOAT:[0-9]+]] 16 0

; Function Attrs: convergent norecurse nounwind
; Casts the addrspace(4) pointer argument to addrspace(1) and calls the mangled
; __spirv_AtomicFMinEXT builtin on it with a bfloat operand of 1.0, returning
; the call's result. The directives inside the body expect an AtomicFMinEXT
; whose result type is the 16-bit float type captured as BFLOAT in the SPIR-V
; text, and the same mangled call in the LLVM IR produced by the second RUN
; pipeline.
; NOTE(review): i32 1 and i32 896 are presumably the Scope and Memory Semantics
; operands of the builtin - confirm against the SPIR-V specification.
define dso_local spir_func bfloat @test_AtomicFMinEXT_bfloat(ptr addrspace(4) align 2 dereferenceable(4) %Arg) {
entry:
%0 = addrspacecast ptr addrspace(4) %Arg to ptr addrspace(1)
; CHECK-SPIRV: AtomicFMinEXT [[BFLOAT]]
; CHECK-LLVM-SPV: call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16({{.*}}bfloat
%ret = tail call spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1) %0, i32 1, i32 896, bfloat 1.000000e+00)
ret bfloat %ret
}

; Function Attrs: convergent
declare dso_local spir_func bfloat @_Z21__spirv_AtomicFMinEXTPU3AS1u6__bf16iiu6__bf16(ptr addrspace(1), i32, i32, bfloat)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 %t.bc -o %t.spv
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s

; CHECK-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
; CHECK-DAG: Extension "SPV_KHR_bfloat16"
; CHECK-DAG: Capability AtomicBFloat16AddINTEL
; CHECK-DAG: Capability BFloat16TypeKHR
; CHECK: TypeInt [[Int:[0-9]+]] 32 0
; CHECK-DAG: Constant [[Int]] [[Device:[0-9]+]] 1 {{$}}
; CHECK-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
; CHECK: TypeFloat [[BFloat:[0-9]+]] 16 0
; CHECK: Variable {{[0-9]+}} [[BFloatPointer:[0-9]+]]
; CHECK: Constant [[BFloat]] [[BFloatValue:[0-9]+]] 16936

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64"

@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8

; Function Attrs: nounwind
; Performs a seq_cst "atomicrmw fadd" of bfloat 42.0 on the addrspace(1) global
; @f and discards the result. The directive in the body expects this to be
; emitted as one AtomicFAddEXT on the bfloat type, operating on the variable
; captured as BFloatPointer, with the Device constant, the
; MemSem_SequentiallyConsistent constant and the BFloatValue constant captured
; by the directives at the top of the file.
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
entry:
%0 = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
; CHECK: AtomicFAddEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Device]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]

ret void
}

attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }

!llvm.module.flags = !{!0}

!0 = !{i32 1, !"wchar_size", i32 4}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16,+SPV_KHR_bfloat16 %t.bc -o %t.spv
; RUN: llvm-spirv -to-text %t.spv -o - | FileCheck %s

; The translated module must declare both extensions and both capabilities.
; The capability patterns include the "Capability" opcode so they only match
; the capability declarations themselves.
; CHECK-DAG: Extension "SPV_INTEL_shader_atomic_bfloat16"
; CHECK-DAG: Extension "SPV_KHR_bfloat16"
; CHECK-DAG: Capability AtomicBFloat16MinMaxINTEL
; CHECK-DAG: Capability BFloat16TypeKHR
; CHECK: TypeInt [[Int:[0-9]+]] 32 0
; CHECK-DAG: Constant [[Int]] [[Device:[0-9]+]] 1 {{$}}
; CHECK-DAG: Constant [[Int]] [[MemSem_SequentiallyConsistent:[0-9]+]] 16
; CHECK: TypeFloat [[BFloat:[0-9]+]] 16 0
; CHECK: Variable {{[0-9]+}} [[BFloatPointer:[0-9]+]]
; CHECK: Constant [[BFloat]] [[BFloatValue:[0-9]+]] 16936

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64"

@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 4

; Function Attrs: nounwind
; Performs a seq_cst "atomicrmw fmin" followed by a seq_cst "atomicrmw fmax" of
; bfloat 42.0 on the addrspace(1) global @f, discarding both results. The
; directives in the body expect them to be emitted, in this order, as an
; AtomicFMinEXT and an AtomicFMaxEXT on the bfloat type, each using the
; BFloatPointer variable and the Device, MemSem_SequentiallyConsistent and
; BFloatValue constants captured by the directives at the top of the file.
; NOTE(review): the function name mentions fadd although the body exercises
; fmin/fmax - presumably carried over from the fadd test; consider renaming.
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
entry:
%0 = atomicrmw fmin ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
; CHECK: AtomicFMinEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Device]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]
%1 = atomicrmw fmax ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst
; CHECK: AtomicFMaxEXT [[BFloat]] {{[0-9]+}} [[BFloatPointer]] [[Device]] [[MemSem_SequentiallyConsistent]] [[BFloatValue]]

ret void
}

attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }

!llvm.module.flags = !{!0}

!0 = !{i32 1, !"wchar_size", i32 4}
29 changes: 29 additions & 0 deletions test/extensions/INTEL/SPV_INTEL_shader_atomic_bfloat16/negative.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; RUN: llvm-as < %s -o %t.bc
; RUN: not llvm-spirv --spirv-ext=+SPV_INTEL_shader_atomic_bfloat16 %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-NO-BF
; RUN: not llvm-spirv --spirv-ext=+SPV_KHR_bfloat16 %t.bc 2>&1 | FileCheck %s --check-prefix=CHECK-NO-ATOM

; CHECK-NO-BF: RequiresExtension: Feature requires the following SPIR-V extension:
; CHECK-NO-BF-NEXT: SPV_KHR_bfloat16
; CHECK-NO-BF-NEXT: NOTE: LLVM module contains bfloat type, translation of which requires this extension

; CHECK-NO-ATOM: RequiresExtension: Feature requires the following SPIR-V extension:
; CHECK-NO-ATOM-NEXT: SPV_INTEL_shader_atomic_bfloat16

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
target triple = "spir64"

@f = common dso_local local_unnamed_addr addrspace(1) global bfloat 0.000000e+00, align 8

; Function Attrs: nounwind
; Performs a seq_cst "atomicrmw fadd" of bfloat 42.0 on the addrspace(1) global
; @f and discards the result. Each RUN line at the top of the file enables only
; one of SPV_INTEL_shader_atomic_bfloat16 and SPV_KHR_bfloat16 and expects
; llvm-spirv to fail with a RequiresExtension diagnostic naming the extension
; that was left out.
define dso_local spir_func void @test_atomicrmw_fadd() local_unnamed_addr #0 {
entry:
%0 = atomicrmw fadd ptr addrspace(1) @f, bfloat 42.000000e+00 seq_cst

ret void
}

attributes #0 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }

!llvm.module.flags = !{!0}

!0 = !{i32 1, !"wchar_size", i32 4}
Loading