Skip to content

Commit db69fa8

Browse files
Add atomic float support (#399)
Co-authored-by: Simeon David Schaub <[email protected]>
1 parent 71cf159 commit db69fa8

File tree

3 files changed

+117
-26
lines changed

3 files changed

+117
-26
lines changed

lib/intrinsics/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SPIRVIntrinsics"
22
uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
33
authors = ["Tim Besard <[email protected]>"]
4-
version = "0.5.5"
4+
version = "0.5.6"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"

lib/intrinsics/src/atomic.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
# provides atomic functions that rely on the OpenCL base atomics, as well as the
44
# cl_khr_int64_base_atomics and cl_khr_int64_extended_atomics extensions.
55

6+
const atomic_float_types = [Float32, Float64]
67
const atomic_integer_types = [UInt32, Int32, UInt64, Int64]
78
const atomic_memory_types = [AS.Workgroup, AS.CrossWorkgroup]
9+
const atomic_types = vcat(atomic_float_types, atomic_integer_types)
810

911

1012
# generically typed
1113

12-
for gentype in atomic_integer_types, as in atomic_memory_types
14+
for gentype in atomic_types, as in atomic_memory_types
1315
@eval begin
1416

1517
@device_function atomic_add!(p::LLVMPtr{$gentype,$as}, val::$gentype) =
@@ -45,15 +47,17 @@ for gentype in atomic_integer_types, as in atomic_memory_types
4547
@device_function atomic_xor!(p::LLVMPtr{$gentype,$as}, val::$gentype) =
4648
@builtin_ccall("atomic_xor", $gentype,
4749
(LLVMPtr{$gentype,$as}, $gentype), p, val)
48-
49-
@device_function atomic_xchg!(p::LLVMPtr{$gentype,$as}, val::$gentype) =
50-
@builtin_ccall("atomic_xchg", $gentype,
51-
(LLVMPtr{$gentype,$as}, $gentype), p, val)
52-
53-
@device_function atomic_cmpxchg!(p::LLVMPtr{$gentype,$as}, cmp::$gentype, val::$gentype) =
54-
@builtin_ccall("atomic_cmpxchg", $gentype,
55-
(LLVMPtr{$gentype,$as}, $gentype, $gentype), p, cmp, val)
56-
50+
end
51+
if gentype in atomic_integer_types
52+
@eval begin
53+
@device_function atomic_xchg!(p::LLVMPtr{$gentype,$as}, val::$gentype) =
54+
@builtin_ccall("atomic_xchg", $gentype,
55+
(LLVMPtr{$gentype,$as}, $gentype), p, val)
56+
57+
@device_function atomic_cmpxchg!(p::LLVMPtr{$gentype,$as}, cmp::$gentype, val::$gentype) =
58+
@builtin_ccall("atomic_cmpxchg", $gentype,
59+
(LLVMPtr{$gentype,$as}, $gentype, $gentype), p, cmp, val)
60+
end
5761
end
5862
end
5963

test/atomics.jl

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,128 @@
11
using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr, known_intrinsics
22

3-
@testset "atomics" begin
3+
# Define the types to test
4+
integer_types = [Int32, UInt32, Int64, UInt64]
5+
float_types = [Float32, Float64]
6+
all_types = vcat(integer_types, float_types)
7+
8+
dev = OpenCL.cl.device()
49

5-
function atomic_count(counter)
6-
OpenCL.@atomic counter[] += 1
10+
# Arithmetic operations
11+
function test_atomic_add(counter::AbstractArray{T}) where T
12+
OpenCL.@atomic counter[] += one(T)
13+
return
14+
end
15+
function test_atomic_sub(counter::AbstractArray{T}) where T
16+
OpenCL.@atomic counter[] -= one(T)
17+
return
18+
end
19+
# Bitwise operations
20+
function test_atomic_and(counter::AbstractArray{T}) where T
21+
OpenCL.@atomic counter[] &= ~(one(T) << (get_global_id() - 1))
22+
return
23+
end
24+
function test_atomic_or(counter::AbstractArray{T}) where T
25+
OpenCL.@atomic counter[] |= one(T) << (get_global_id() - 1)
26+
return
27+
end
28+
function test_atomic_xor(counter::AbstractArray{T}) where T
29+
OpenCL.@atomic counter[] ⊻= one(T) << ((get_global_id() - 1) % 32)
30+
return
31+
end
32+
# Min/max operations - use low-level API directly
33+
function test_atomic_max(counter::AbstractArray{T}) where T
34+
OpenCL.atomic_max!(pointer(counter), T(get_global_id()))
35+
return
36+
end
37+
function test_atomic_min(counter::AbstractArray{T}) where T
38+
OpenCL.atomic_min!(pointer(counter), T(get_global_id()))
39+
return
40+
end
41+
# Exchange operation - use low-level API directly
42+
function test_atomic_xchg(counter::AbstractArray{T}) where T
43+
OpenCL.atomic_xchg!(pointer(counter), one(T))
744
return
845
end
46+
# Compare-and-swap operation - use low-level API directly
47+
function test_atomic_cas(counter::AbstractArray{T}) where T
48+
OpenCL.atomic_cmpxchg!(pointer(counter), zero(T), one(T))
49+
return
50+
end
51+
52+
# Define atomic operations to test
53+
atomic_operations = [
54+
# op, init_val, expected_val
55+
(test_atomic_add, 0, 1000),
56+
(test_atomic_sub, 1000, 0),
57+
(test_atomic_and, typemax(UInt64), 0),
58+
(test_atomic_or, 0, typemax(UInt64)),
59+
(test_atomic_xor, 0, typemax(UInt32) << 8),
60+
(test_atomic_max, 0, 1000),
61+
(test_atomic_min, 1000, 1),
62+
(test_atomic_xchg, 0, 1),
63+
(test_atomic_cas, 0, 1),
64+
]
65+
@testset "atomics" begin
66+
@testset "$kernel_func - $T" for (kernel_func, init_val, expected_val) in atomic_operations, T in all_types
67+
# Skip Int64/UInt64 if not supported
68+
if sizeof(T) == 8 && T <: Integer && !("cl_khr_int64_extended_atomics" in dev.extensions)
69+
continue
70+
end
71+
72+
# Skip Float64 if not supported
73+
if T == Float64 && !("cl_khr_fp64" in dev.extensions)
74+
continue
75+
end
976

10-
@testset "atomic_add! ($T)" for T in [Int32, UInt32, Int64, UInt64]
11-
if sizeof(T) == 4 || "cl_khr_int64_extended_atomics" in cl.device().extensions
12-
a = OpenCL.zeros(T)
13-
@opencl global_size=1000 atomic_count(a)
14-
@test OpenCL.@allowscalar a[] == 1000
77+
# Bitwise operations (only valid for integers)
78+
if kernel_func in [test_atomic_and, test_atomic_or, test_atomic_xor] && T <: AbstractFloat
79+
continue
1580
end
81+
82+
# Min/max operations (only supported for 32-bit integers in OpenCL)
83+
if kernel_func in [test_atomic_min, test_atomic_max] && !(T in [Int32, UInt32])
84+
continue
85+
end
86+
87+
if T <: Integer
88+
init_val %= T
89+
expected_val %= T
90+
end
91+
92+
a = OpenCL.fill(T(init_val))
93+
@opencl global_size=1000 kernel_func(a)
94+
result_val = OpenCL.@allowscalar a[]
95+
@test result_val === T(expected_val)
1696
end
1797

98+
99+
@testset "atomic_add! ($T)" for T in [Float32, Float64]
100+
# Float64 requires cl_khr_fp64 extension
101+
if T == Float64 && !("cl_khr_fp64" in cl.device().extensions)
102+
continue
103+
end
18104
if "cl_ext_float_atomics" in cl.device().extensions
19-
function atomic_float_add(counter, val)
105+
@eval function atomic_float_add(counter, val::$T)
20106
@builtin_ccall(
21-
"atomic_add", Float32,
22-
(LLVMPtr{Float32, AS.CrossWorkgroup}, Float32),
107+
"atomic_add", $T,
108+
(LLVMPtr{$T, AS.CrossWorkgroup}, $T),
23109
pointer(counter), val,
24110
)
25111
return
26112
end
27113

28114
@testset "SPV_EXT_shader_atomic_float_add extension" begin
29-
a = OpenCL.zeros(Float32)
30-
@opencl global_size = 1000 extensions = ["SPV_EXT_shader_atomic_float_add"] atomic_float_add(a, 1.0f0)
31-
@test OpenCL.@allowscalar a[] == 1000.0f0
115+
a = OpenCL.zeros(T)
116+
@opencl global_size = 1000 extensions = ["SPV_EXT_shader_atomic_float_add"] atomic_float_add(a, one(T))
117+
@test OpenCL.@allowscalar a[] == T(1000.0)
32118

33119
spv = sprint() do io
34-
OpenCL.code_native(io, atomic_float_add, Tuple{CLDeviceArray{Float32, 0, 1}, Float32}; extensions = ["SPV_EXT_shader_atomic_float_add"])
120+
OpenCL.code_native(io, atomic_float_add, Tuple{CLDeviceArray{T, 0, 1}, T}; extensions = ["SPV_EXT_shader_atomic_float_add"])
35121
end
36122
@test occursin("OpExtension \"SPV_EXT_shader_atomic_float_add\"", spv)
37123
@test occursin("OpAtomicFAddEXT", spv)
38124
end
39125
end
40126

41127
end
128+
end

0 commit comments

Comments
 (0)