@@ -6,6 +6,10 @@ import RandomNumbers
66
77# global state
88
9+ # 16 is the lower bound for `threads_per_simdgroup()`, 1024 is the upper bound
10+ # for `threads_per_threadgroup()`, so there are at most 1024 / 16 = 64 simdgroups per threadgroup
11+ const max_simdgroups_per_threadgroup = 64
12+
913@inline @generated function emit_global_random_values (:: Val{name} ) where name
1014 @dispose ctx= Context () begin
1115 T_val = convert (LLVMType, UInt32)
@@ -16,7 +20,7 @@ import RandomNumbers
1620 mod = LLVM. parent (llvm_f)
1721
1822 # create a global memory global variable
19- T_global = LLVM. ArrayType (T_val, 32 )
23+ T_global = LLVM. ArrayType (T_val, max_simdgroups_per_threadgroup )
2024 gv = GlobalVariable (mod, T_global, " global_random_$(name) " , AS. ThreadGroup)
2125 linkage! (gv, LLVM. API. LLVMLinkOnceAnyLinkage)
2226 initializer! (gv, LLVM. null (T_global))
@@ -39,26 +43,25 @@ import RandomNumbers
3943 end
4044end
4145
42- # shared memory with the actual seed, per warp , loaded lazily or overridden by calling `seed!`
46+ # shared memory with the actual seed, per simdgroup, loaded lazily or overridden by calling `seed!`
4347@inline function global_random_keys ()
4448 ptr = emit_global_random_values (Val {:keys} ()):: LLVMPtr{UInt32,AS.ThreadGroup}
45- return MtlDeviceArray {UInt32,1,AS.ThreadGroup} ((32 ,), ptr)
49+ return MtlDeviceArray {UInt32,1,AS.ThreadGroup} ((max_simdgroups_per_threadgroup ,), ptr)
4650end
4751
48- # shared memory with per-warp counters, incremented when generating numbers
52+ # shared memory with per-simdgroup counters, incremented when generating numbers
4953@inline function global_random_counters ()
5054 ptr = emit_global_random_values (Val {:counters} ()):: LLVMPtr{UInt32,AS.ThreadGroup}
51- return MtlDeviceArray {UInt32,1,AS.ThreadGroup} ((32 ,), ptr)
55+ return MtlDeviceArray {UInt32,1,AS.ThreadGroup} ((max_simdgroups_per_threadgroup ,), ptr)
5256end
5357
5458# initialization function, called automatically at the start of each kernel because
5559# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
5660function initialize_rng_state ()
57- threadId = thread_position_in_threadgroup (). x
58- warpId = (threadId - Int32 (1 )) >> 0x5 + Int32 (1 ) # fld1 by 32
61+ simdgroupId = simdgroup_index_in_threadgroup ()
5962
60- @inbounds global_random_keys ()[warpId ] = kernel_state (). random_seed
61- @inbounds global_random_counters ()[warpId ] = 0
63+ @inbounds global_random_keys ()[simdgroupId ] = kernel_state (). random_seed
64+ @inbounds global_random_counters ()[simdgroupId ] = 0
6265end
6366
6467# generators
7679@inline Philox2x32 () = Philox2x32 {7} ()
7780
7881@inline function Base. getproperty (rng:: Philox2x32 , field:: Symbol )
79- threadId = thread_position_in_threadgroup (). x
80- warpId = (threadId - Int32 (1 )) >> 0x5 + Int32 (1 ) # fld1 by 32
82+ simdgroupId = simdgroup_index_in_threadgroup ()
8183
8284 if field === :seed
8385 @inbounds global_random_seed ()[1 ]
8486 elseif field === :key
85- @inbounds global_random_keys ()[warpId ]
87+ @inbounds global_random_keys ()[simdgroupId ]
8688 elseif field === :ctr1
87- @inbounds global_random_counters ()[warpId ]
89+ @inbounds global_random_counters ()[simdgroupId ]
8890 elseif field === :ctr2
89- globalId = thread_position_in_grid (). x
91+ globalId = thread_position_in_grid (). x +
92+ (thread_position_in_grid (). y - 1 i32) * threads_per_grid (). x +
93+ (thread_position_in_grid (). z - 1 i32) * threads_per_grid (). x * threads_per_grid (). y
9094 globalId % UInt32
9195 end :: UInt32
9296end
9397
9498@inline function Base. setproperty! (rng:: Philox2x32 , field:: Symbol , x)
95- threadId = thread_position_in_threadgroup (). x
96- warpId = (threadId - Int32 (1 )) >> 0x5 + Int32 (1 ) # fld1 by 32
99+ simdgroupId = simdgroup_index_in_threadgroup ()
97100
98101 if field === :key
99- @inbounds global_random_keys ()[warpId ] = x
102+ @inbounds global_random_keys ()[simdgroupId ] = x
100103 elseif field === :ctr1
101- @inbounds global_random_counters ()[warpId ] = x
104+ @inbounds global_random_counters ()[simdgroupId ] = x
102105 end
103106end
104107
108111 Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])
109112
110113Seed the on-device Philox2x32 generator with `seed`, truncated to a `UInt32`.
111- Should be called by at least one thread per warp .
114+ Should be called by at least one thread per simdgroup.
112115"""
113116function Random. seed! (rng:: Philox2x32 , seed:: Integer , counter:: Integer = UInt32 (0 ))
114117 rng. key = seed % UInt32
128131"""
129132 Random.rand(rng::Philox2x32, UInt64)
130133
131- Generate a byte of random data using the on-device Tausworthe generator.
134+ Generate 64 bits of random data (a `UInt64`) using the on-device Philox generator.
132135"""
133136function Random. rand (rng:: Philox2x32{R} ,:: Type{UInt64} ) where {R}
134137 ctr1, ctr2, key = rng. ctr1, rng. ctr2, rng. key
@@ -150,9 +153,9 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
150153 if R > 14 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
151154 if R > 15 key = philox2x_bumpkey (key); ctr1, ctr2 = philox2x_round (ctr1, ctr2, key); end
152155
153- # update the warp counter
154- # NOTE: this performs the same update on every thread in the warp , but each warp writes
155- # to a unique location so the duplicate writes are innocuous
156+ # update the simdgroup counter
157+ # NOTE: this performs the same update on every thread in the simdgroup, but each
158+ # simdgroup writes to a unique location so the duplicate writes are innocuous
156159 # XXX : what if this overflows? we can't increment ctr2. bump the key?
157160 rng. ctr1 += Int32 (1 )
158161
0 commit comments