Skip to content

Commit c7c9db1

Browse files
simeonschaubjpsamaroomaleadt
authored
Device-side RNG (#657)
Co-authored-by: Julian Samaroo <[email protected]> Co-authored-by: Tim Besard <[email protected]>
1 parent 1214dbd commit c7c9db1

File tree

8 files changed

+391
-3
lines changed

8 files changed

+391
-3
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2020
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2121
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2222
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
23+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
24+
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
2325
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
2426
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
2527
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -38,7 +40,7 @@ CEnum = "0.4, 0.5"
3840
CodecBzip2 = "0.8.5"
3941
ExprTools = "0.1"
4042
GPUArrays = "11.2.1"
41-
GPUCompiler = "0.26, 0.27, 1"
43+
GPUCompiler = "1.7.1"
4244
GPUToolbox = "0.1, 0.2, 0.3, 1"
4345
KernelAbstractions = "0.9.38"
4446
LLVM = "7.2, 8, 9"
@@ -49,6 +51,8 @@ PrecompileTools = "1"
4951
Preferences = "1"
5052
Printf = "1"
5153
Random = "1"
54+
Random123 = "1.7.1"
55+
RandomNumbers = "1.6.0"
5256
SHA = "0.7"
5357
ScopedValues = "1.3.0"
5458
SpecialFunctions = "2"

src/Metal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ include("device/intrinsics/synchronization.jl")
3737
include("device/intrinsics/memory.jl")
3838
include("device/intrinsics/simd.jl")
3939
include("device/intrinsics/atomics.jl")
40+
include("device/random.jl")
4041
include("device/quirks.jl")
4142

4243
# array essentials

src/compiler/compilation.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,71 @@ GPUCompiler.runtime_module(::MetalCompilerJob) = Metal
88

99
GPUCompiler.method_table(::MetalCompilerJob) = method_table
1010

11+
GPUCompiler.kernel_state_type(job::MetalCompilerJob) = KernelState
12+
13+
function GPUCompiler.finish_module!(@nospecialize(job::MetalCompilerJob),
14+
mod::LLVM.Module, entry::LLVM.Function)
15+
entry = invoke(GPUCompiler.finish_module!,
16+
Tuple{CompilerJob{MetalCompilerTarget}, LLVM.Module, LLVM.Function},
17+
job, mod, entry)
18+
19+
# if this kernel uses our RNG, we should prime the shared state.
20+
# XXX: these transformations should really happen at the Julia IR level...
21+
if job.config.kernel && haskey(globals(mod), "global_random_keys")
22+
f = initialize_rng_state
23+
ft = typeof(f)
24+
tt = Tuple{}
25+
26+
# create a deferred compilation job for `initialize_rng_state()`
27+
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
28+
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
29+
job = CompilerJob(src, cfg, job.world)
30+
id = length(GPUCompiler.deferred_codegen_jobs) + 1
31+
GPUCompiler.deferred_codegen_jobs[id] = job
32+
33+
# generate IR for calls to `deferred_codegen` and the resulting function pointer
34+
top_bb = first(blocks(entry))
35+
bb = BasicBlock(top_bb, "initialize_rng")
36+
@dispose builder=IRBuilder() begin
37+
position!(builder, bb)
38+
subprogram = LLVM.subprogram(entry)
39+
if subprogram !== nothing
40+
loc = DILocation(0, 0, subprogram)
41+
debuglocation!(builder, loc)
42+
end
43+
debuglocation!(builder, first(instructions(top_bb)))
44+
45+
# call the `deferred_codegen` marker function
46+
T_ptr = if LLVM.version() >= v"17"
47+
LLVM.PointerType()
48+
elseif VERSION >= v"1.12.0-DEV.225"
49+
LLVM.PointerType(LLVM.Int8Type())
50+
else
51+
LLVM.Int64Type()
52+
end
53+
T_id = convert(LLVMType, Int)
54+
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
55+
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
56+
functions(mod)["deferred_codegen"]
57+
else
58+
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
59+
end
60+
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])
61+
62+
# call the `initialize_rng_state` function
63+
rt = Core.Compiler.return_type(f, tt)
64+
llvm_rt = convert(LLVMType, rt)
65+
llvm_ft = LLVM.FunctionType(llvm_rt)
66+
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
67+
call!(builder, llvm_ft, fptr)
68+
br!(builder, top_bb)
69+
end
70+
71+
# XXX: put some of the above behind GPUCompiler abstractions
72+
# (e.g., a compile-time version of `deferred_codegen`)
73+
end
74+
return entry
75+
end
1176

1277
function GPUCompiler.finish_ir!(@nospecialize(job::MetalCompilerJob),
1378
mod::LLVM.Module, entry::LLVM.Function)

src/compiler/execution.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,14 @@ end
275275
(threads.width * threads.height * threads.depth) > kernel.pipeline.maxTotalThreadsPerThreadgroup &&
276276
throw(ArgumentError("Number of threads in group ($(threads.width * threads.height * threads.depth)) should not exceed $(kernel.pipeline.maxTotalThreadsPerThreadgroup)"))
277277

278+
kernel_state = KernelState(rand(UInt32))
279+
278280
cmdbuf = MTLCommandBuffer(queue)
279281
cmdbuf.label = "MTLCommandBuffer($(nameof(kernel.f)))"
280282
cce = MTLComputeCommandEncoder(cmdbuf)
281283
argument_buffers = try
282284
MTL.set_function!(cce, kernel.pipeline)
283-
bufs = encode_arguments!(cce, kernel, kernel.f, args...)
285+
bufs = encode_arguments!(cce, kernel, kernel_state, kernel.f, args...)
284286
MTL.append_current_function!(cce, groups, threads)
285287
bufs
286288
finally
@@ -295,7 +297,7 @@ end
295297
# kernel has actually completed.
296298
#
297299
# TODO: is there a way to bind additional resources to the command buffer?
298-
roots = [kernel.f, args]
300+
roots = [kernel.f, kernel_state, args]
299301
MTL.on_completed(cmdbuf) do buf
300302
empty!(roots)
301303
foreach(free, argument_buffers)

src/device/random.jl

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
## random number generation
2+
3+
using Random
4+
import RandomNumbers
5+
6+
7+
# global state
8+
9+
@inline @generated function emit_global_random_values(::Val{name}) where name
10+
@dispose ctx=Context() begin
11+
T_val = convert(LLVMType, UInt32)
12+
T_ptr = convert(LLVMType, LLVMPtr{UInt32,AS.ThreadGroup})
13+
14+
# define function and get LLVM module
15+
llvm_f, _ = create_function(T_ptr)
16+
mod = LLVM.parent(llvm_f)
17+
18+
# create a global memory global variable
19+
T_global = LLVM.ArrayType(T_val, 32)
20+
gv = GlobalVariable(mod, T_global, "global_random_$(name)", AS.ThreadGroup)
21+
linkage!(gv, LLVM.API.LLVMLinkOnceAnyLinkage)
22+
initializer!(gv, LLVM.null(T_global))
23+
unnamed_addr!(gv, true)
24+
alignment!(gv, 4)
25+
26+
# generate IR
27+
@dispose builder=IRBuilder() begin
28+
entry = BasicBlock(llvm_f, "entry")
29+
position!(builder, entry)
30+
31+
ptr = gep!(builder, T_global, gv, [ConstantInt(0), ConstantInt(0)])
32+
33+
untyped_ptr = bitcast!(builder, ptr, T_ptr)
34+
35+
ret!(builder, untyped_ptr)
36+
end
37+
38+
call_function(llvm_f, LLVMPtr{UInt32,AS.ThreadGroup})
39+
end
40+
end
41+
42+
# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
43+
@inline function global_random_keys()
44+
ptr = emit_global_random_values(Val{:keys}())::LLVMPtr{UInt32,AS.ThreadGroup}
45+
return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((32,), ptr)
46+
end
47+
48+
# shared memory with per-warp counters, incremented when generating numbers
49+
@inline function global_random_counters()
50+
ptr = emit_global_random_values(Val{:counters}())::LLVMPtr{UInt32,AS.ThreadGroup}
51+
return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((32,), ptr)
52+
end
53+
54+
# initialization function, called automatically at the start of each kernel because
55+
# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
56+
function initialize_rng_state()
57+
threadId = thread_position_in_threadgroup_1d()
58+
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
59+
60+
@inbounds global_random_keys()[warpId] = kernel_state().random_seed
61+
@inbounds global_random_counters()[warpId] = 0
62+
end
63+
64+
# generators
65+
66+
using Random123: philox2x_round, philox2x_bumpkey
67+
68+
# GPU-compatible/optimized version of the generator from Random123.jl
69+
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
70+
@inline function Philox2x32{R}() where R
71+
return new{R}()
72+
end
73+
end
74+
75+
# default to 7 rounds; enough to pass SmallCrush
76+
@inline Philox2x32() = Philox2x32{7}()
77+
78+
@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
79+
threadId = thread_position_in_threadgroup_1d()
80+
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
81+
82+
if field === :seed
83+
@inbounds global_random_seed()[1]
84+
elseif field === :key
85+
@inbounds global_random_keys()[warpId]
86+
elseif field === :ctr1
87+
@inbounds global_random_counters()[warpId]
88+
elseif field === :ctr2
89+
globalId = thread_position_in_grid_1d()
90+
globalId % UInt32
91+
end::UInt32
92+
end
93+
94+
@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
95+
threadId = thread_position_in_threadgroup_1d()
96+
warpId = (threadId - Int32(1)) >> 0x5 + Int32(1) # fld1 by 32
97+
98+
if field === :key
99+
@inbounds global_random_keys()[warpId] = x
100+
elseif field === :ctr1
101+
@inbounds global_random_counters()[warpId] = x
102+
end
103+
end
104+
105+
@device_override @inline Random.default_rng() = Philox2x32()
106+
107+
"""
108+
Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])
109+
110+
Seed the on-device Philox2x32 generator with an UInt32 number.
111+
Should be called by at least one thread per warp.
112+
"""
113+
function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=UInt32(0))
114+
rng.key = seed % UInt32
115+
rng.ctr1 = counter
116+
return
117+
end
118+
119+
# seeding the implicit default RNG
120+
@static if VERSION >= v"1.11-"
121+
@device_override Random.seed!(seed) =
122+
Random.seed!(Random.default_rng(), seed)
123+
else
124+
@device_override Random.seed!(::Random._GLOBAL_RNG, seed) =
125+
Random.seed!(Random.default_rng(), seed)
126+
end
127+
128+
"""
129+
Random.rand(rng::Philox2x32, UInt32)
130+
131+
Generate a byte of random data using the on-device Tausworthe generator.
132+
"""
133+
function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
134+
ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key
135+
136+
if R > 0 ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
137+
if R > 1 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
138+
if R > 2 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
139+
if R > 3 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
140+
if R > 4 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
141+
if R > 5 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
142+
if R > 6 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
143+
if R > 7 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
144+
if R > 8 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
145+
if R > 9 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
146+
if R > 10 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
147+
if R > 11 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
148+
if R > 12 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
149+
if R > 13 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
150+
if R > 14 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
151+
if R > 15 key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
152+
153+
# update the warp counter
154+
# NOTE: this performs the same update on every thread in the warp, but each warp writes
155+
# to a unique location so the duplicate writes are innocuous
156+
# XXX: what if this overflows? we can't increment ctr2. bump the key?
157+
rng.ctr1 += Int32(1)
158+
159+
# NOTE: it's too expensive to keep both numbers around in case the user only wanted one,
160+
# so just make our 2x32 generator return 64-bit numbers by default.
161+
return (ctr1 % UInt64) << 32 | (ctr2 % UInt64)
162+
end
163+
164+
165+
# normally distributed
166+
167+
# use the AbstractFloat fallback from Base, which doesn't widen and only relies on `rand()`.
168+
# the Ziggurat method used by other back-ends relies on Float64 support.
169+
@device_override @inline function Random.randn(rng::Philox2x32, ::Type{T}) where {T <: AbstractFloat}
170+
@invoke Random.randn(rng::AbstractRNG, T::Type{<:AbstractFloat})
171+
end
172+
173+
174+
# exponentially distributed
175+
176+
# use the AbstractFloat fallback from Base, which doesn't widen and only relies on `rand()`.
177+
# the Ziggurat method used by other back-ends relies on Float64 support.
178+
@device_override @inline function Random.randexp(rng::Philox2x32, ::Type{T}) where {T <: AbstractFloat}
179+
@invoke Random.randexp(rng::AbstractRNG, T::Type{<:AbstractFloat})
180+
end

src/device/runtime.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,11 @@ function report_exception_frame(idx, func, file, line)
3232
# @cuprintf(" [%i] %s at %s:%i\n", idx, func, file, line)
3333
return
3434
end
35+
36+
## kernel state
37+
38+
struct KernelState
39+
random_seed::UInt32
40+
end
41+
42+
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
55
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
66
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
77
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
8+
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
89
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
910
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
1011
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"

0 commit comments

Comments
 (0)