Skip to content

Commit 4821fa7

Browse files
committed
Fix RNG initialization with 3D indices.
1 parent e1168de commit 4821fa7

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed

src/device/random.jl

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ import RandomNumbers
66

77
# global state
88

9+
# 16 is the lower bound for `threads_per_simdgroup()`, 1024 is the upper bound
10+
# for `threads_per_threadgroup()`, so at most 1024 / 16 = 64 simdgroups per threadgroup
11+
const max_simdgroups_per_threadgroup = 64
12+
913
@inline @generated function emit_global_random_values(::Val{name}) where name
1014
@dispose ctx=Context() begin
1115
T_val = convert(LLVMType, UInt32)
@@ -16,7 +20,7 @@ import RandomNumbers
1620
mod = LLVM.parent(llvm_f)
1721

1822
# create a global variable in threadgroup memory
19-
T_global = LLVM.ArrayType(T_val, 32)
23+
T_global = LLVM.ArrayType(T_val, max_simdgroups_per_threadgroup)
2024
gv = GlobalVariable(mod, T_global, "global_random_$(name)", AS.ThreadGroup)
2125
linkage!(gv, LLVM.API.LLVMLinkOnceAnyLinkage)
2226
initializer!(gv, LLVM.null(T_global))
@@ -39,26 +43,25 @@ import RandomNumbers
3943
end
4044
end
4145

42-
# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
46+
# threadgroup memory with the actual seed, per simdgroup, loaded lazily or overridden by calling `seed!`
4347
@inline function global_random_keys()
4448
ptr = emit_global_random_values(Val{:keys}())::LLVMPtr{UInt32,AS.ThreadGroup}
45-
return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((32,), ptr)
49+
return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((max_simdgroups_per_threadgroup,), ptr)
4650
end
4751

48-
# shared memory with per-warp counters, incremented when generating numbers
52+
# threadgroup memory with per-simdgroup counters, incremented when generating numbers
4953
@inline function global_random_counters()
5054
ptr = emit_global_random_values(Val{:counters}())::LLVMPtr{UInt32,AS.ThreadGroup}
51-
return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((32,), ptr)
55+
return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((max_simdgroups_per_threadgroup,), ptr)
5256
end
5357

5458
# initialization function, called automatically at the start of each kernel because
5559
# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
5660
function initialize_rng_state()
57-
threadId = thread_position_in_threadgroup().x
58-
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
61+
simdgroupId = simdgroup_index_in_threadgroup()
5962

60-
@inbounds global_random_keys()[warpId] = kernel_state().random_seed
61-
@inbounds global_random_counters()[warpId] = 0
63+
@inbounds global_random_keys()[simdgroupId] = kernel_state().random_seed
64+
@inbounds global_random_counters()[simdgroupId] = 0
6265
end
6366

6467
# generators
@@ -76,29 +79,29 @@ end
7679
@inline Philox2x32() = Philox2x32{7}()
7780

7881
@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
79-
threadId = thread_position_in_threadgroup().x
80-
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
82+
simdgroupId = simdgroup_index_in_threadgroup()
8183

8284
if field === :seed
8385
@inbounds global_random_seed()[1]
8486
elseif field === :key
85-
@inbounds global_random_keys()[warpId]
87+
@inbounds global_random_keys()[simdgroupId]
8688
elseif field === :ctr1
87-
@inbounds global_random_counters()[warpId]
89+
@inbounds global_random_counters()[simdgroupId]
8890
elseif field === :ctr2
89-
globalId = thread_position_in_grid().x
91+
globalId = thread_position_in_grid().x +
92+
(thread_position_in_grid().y - 1i32) * threads_per_grid().x +
93+
(thread_position_in_grid().z - 1i32) * threads_per_grid().x * threads_per_grid().y
9094
globalId % UInt32
9195
end::UInt32
9296
end
9397

9498
@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
95-
threadId = thread_position_in_threadgroup().x
96-
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
99+
simdgroupId = simdgroup_index_in_threadgroup()
97100

98101
if field === :key
99-
@inbounds global_random_keys()[warpId] = x
102+
@inbounds global_random_keys()[simdgroupId] = x
100103
elseif field === :ctr1
101-
@inbounds global_random_counters()[warpId] = x
104+
@inbounds global_random_counters()[simdgroupId] = x
102105
end
103106
end
104107

@@ -108,7 +111,7 @@ end
108111
Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])
109112
110113
Seed the on-device Philox2x32 generator with a `UInt32` number.
111-
Should be called by at least one thread per warp.
114+
Should be called by at least one thread per simdgroup.
112115
"""
113116
function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=UInt32(0))
114117
rng.key = seed % UInt32
@@ -128,7 +131,7 @@ end
128131
"""
129132
Random.rand(rng::Philox2x32, UInt64)
130133
131-
Generate a byte of random data using the on-device Tausworthe generator.
134+
Generate 64 bits of random data using the on-device Philox generator.
132135
"""
133136
function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
134137
ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key
@@ -150,9 +153,9 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
150153
if R > 14 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
151154
if R > 15 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
152155

153-
# update the warp counter
154-
# NOTE: this performs the same update on every thread in the warp, but each warp writes
155-
# to a unique location so the duplicate writes are innocuous
156+
# update the simdgroup counter
157+
# NOTE: this performs the same update on every thread in the simdgroup, but each
158+
# simdgroup writes to a unique location so the duplicate writes are innocuous
156159
# XXX: what if this overflows? we can't increment ctr2. bump the key?
157160
rng.ctr1 += Int32(1)
158161

0 commit comments

Comments
 (0)