Skip to content

Commit 9796d5a

Browse files
authored
Merge pull request #2035 from JuliaGPU/tb/rand_seed
rand: seed kernels from the host.
2 parents fade845 + 556b23e commit 9796d5a

File tree

6 files changed

+102
-31
lines changed

6 files changed

+102
-31
lines changed

.buildkite/pipeline.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ steps:
1919
cuda: "*"
2020
commands: |
2121
julia --project -e '
22-
# make sure the 1.6-era Manifest works on this Julia version
22+
# make sure the 1.7-era Manifest works on this Julia version
2323
using Pkg
2424
Pkg.resolve()
2525
@@ -32,7 +32,6 @@ steps:
3232
matrix:
3333
setup:
3434
julia:
35-
- "1.6"
3635
- "1.7"
3736
- "1.8"
3837
- "1.9"
@@ -315,10 +314,10 @@ steps:
315314
matrix:
316315
setup:
317316
julia:
318-
- "1.6"
319317
- "1.7"
320318
- "1.8"
321319
- "1.9"
320+
- "1.10"
322321
- "nightly"
323322
adjustments:
324323
- with:

src/compiler/compilation.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,68 @@ GPUCompiler.method_table(@nospecialize(job::CUDACompilerJob)) = method_table
4242

4343
GPUCompiler.kernel_state_type(job::CUDACompilerJob) = KernelState
4444

45+
function GPUCompiler.finish_module!(@nospecialize(job::CUDACompilerJob),
46+
mod::LLVM.Module, entry::LLVM.Function)
47+
entry = invoke(GPUCompiler.finish_module!,
48+
Tuple{CompilerJob{PTXCompilerTarget}, LLVM.Module, LLVM.Function},
49+
job, mod, entry)
50+
51+
# if this kernel uses our RNG, we should prime the shared state.
52+
# XXX: these transformations should really happen at the Julia IR level...
53+
if haskey(globals(mod), "global_random_keys")
54+
f = initialize_rng_state
55+
ft = typeof(f)
56+
tt = Tuple{}
57+
58+
# don't recurse into `initialize_rng_state()` itself
59+
if job.source.specTypes.parameters[1] == ft
60+
return entry
61+
end
62+
63+
# create a deferred compilation job for `initialize_rng_state()`
64+
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
65+
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
66+
job = CompilerJob(src, cfg, job.world)
67+
id = length(GPUCompiler.deferred_codegen_jobs) + 1
68+
GPUCompiler.deferred_codegen_jobs[id] = job
69+
70+
# generate IR for calls to `deferred_codegen` and the resulting function pointer
71+
top_bb = first(blocks(entry))
72+
bb = BasicBlock(top_bb, "initialize_rng")
73+
LLVM.@dispose builder=IRBuilder() begin
74+
position!(builder, bb)
75+
subprogram = LLVM.get_subprogram(entry)
76+
if subprogram !== nothing
77+
loc = DILocation(0, 0, subprogram)
78+
debuglocation!(builder, loc)
79+
end
80+
debuglocation!(builder, first(instructions(top_bb)))
81+
82+
# call the `deferred_codegen` marker function
83+
T_ptr = LLVM.Int64Type()
84+
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_ptr])
85+
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
86+
functions(mod)["deferred_codegen"]
87+
else
88+
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
89+
end
90+
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])
91+
92+
# call the `initialize_rng_state` function
93+
rt = Core.Compiler.return_type(f, tt)
94+
llvm_rt = convert(LLVMType, rt)
95+
llvm_ft = LLVM.FunctionType(llvm_rt)
96+
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
97+
call!(builder, llvm_ft, fptr)
98+
br!(builder, top_bb)
99+
end
100+
101+
# XXX: put some of the above behind GPUCompiler abstractions
102+
# (e.g., a compile-time version of `deferred_codegen`)
103+
end
104+
return entry
105+
end
106+
45107

46108
## compiler implementation (cache, configure, compile, and link)
47109

src/compiler/execution.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,9 @@ end
212212
end
213213
end
214214

215-
# add the kernel state
215+
# add the kernel state, passing an instance with a unique seed
216216
pushfirst!(call_t, KernelState)
217-
pushfirst!(call_args, :(kernel.state))
217+
pushfirst!(call_args, :(KernelState(kernel.state.exception_flag, make_seed(kernel))))
218218

219219
# finalize types
220220
call_tt = Base.to_tuple_type(call_t)
@@ -329,7 +329,7 @@ function cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
329329
if kernel === nothing
330330
# create the kernel state object
331331
exception_ptr = create_exceptions!(fun.mod)
332-
state = KernelState(exception_ptr)
332+
state = KernelState(exception_ptr, UInt32(0))
333333

334334
kernel = HostKernel{F,tt}(f, fun, state)
335335
_kernel_instances[key] = kernel
@@ -345,6 +345,8 @@ function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs
345345
call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
346346
end
347347

348+
make_seed(::HostKernel) = Random.rand(UInt32)
349+
348350

349351
## device-side kernels
350352

@@ -375,6 +377,9 @@ end
375377

376378
(kernel::DeviceKernel)(args...; kwargs...) = call(kernel, args...; kwargs...)
377379

380+
# re-use the parent kernel's seed to avoid the need for an RNG
381+
make_seed(::DeviceKernel) = kernel_state().random_seed
382+
378383

379384
## other
380385

src/device/random.jl

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ import RandomNumbers
66

77
# global state
88

9-
# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
9+
# we cannot store RNG state in thread-local memory (i.e. in the `rng` object) because that
10+
# inflates register usage. instead, we store it in shared memory, with one entry per warp.
11+
#
12+
# XXX: this implies that state is shared between `rng` objects, which can be surprising.
13+
14+
# array with seeds, per warp, initialized on kernel start or by calling `seed!`
1015
@eval @inline function global_random_keys()
1116
ptr = Base.llvmcall(
1217
$("""@global_random_keys = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
@@ -20,7 +25,7 @@ import RandomNumbers
2025
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
2126
end
2227

23-
# shared memory with per-warp counters, incremented when generating numbers
28+
# array with per-warp counters, incremented when generating numbers
2429
@eval @inline function global_random_counters()
2530
ptr = Base.llvmcall(
2631
$("""@global_random_counters = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
@@ -34,6 +39,17 @@ end
3439
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
3540
end
3641

42+
# initialization function, called automatically at the start of each kernel because
43+
# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
44+
function initialize_rng_state()
45+
threadId = threadIdx().x + (threadIdx().y - 1i32) * blockDim().x +
46+
(threadIdx().z - 1i32) * blockDim().x * blockDim().y
47+
warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1
48+
49+
@inbounds global_random_keys()[warpId] = kernel_state().random_seed
50+
@inbounds global_random_counters()[warpId] = 0
51+
end
52+
3753
@device_override Random.make_seed() = clock(UInt32)
3854

3955

@@ -43,19 +59,7 @@ using Random123: philox2x_round, philox2x_bumpkey
4359

4460
# GPU-compatible/optimized version of the generator from Random123.jl
4561
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
46-
@inline function Philox2x32{R}() where R
47-
rng = new{R}()
48-
if rng.key == 0
49-
# initialize the key. this happens when first accessing the (0-initialized)
50-
# shared memory key from each block. if we ever want to make the device seed
51-
# controlable from the host, this would be the place to read a global seed.
52-
#
53-
# note however that it is undefined how shared memory persists across e.g.
54-
# launches, so we may not be able to rely on the zero initalization then.
55-
rng.key = Random.make_seed()
56-
end
57-
return rng
58-
end
62+
# NOTE: the state is stored globally; see comments at the top of this file.
5963
end
6064

6165
# default to 7 rounds; enough to pass SmallCrush
@@ -66,9 +70,7 @@ end
6670
(threadIdx().z - 1i32) * blockDim().x * blockDim().y
6771
warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1
6872

69-
if field === :seed
70-
@inbounds global_random_seed()[1]
71-
elseif field === :key
73+
if field === :key
7274
@inbounds global_random_keys()[warpId]
7375
elseif field === :ctr1
7476
@inbounds global_random_counters()[warpId]
@@ -139,6 +141,7 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
139141
# update the warp counter
140142
# NOTE: this performs the same update on every thread in the warp, but each warp writes
141143
# to a unique location so the duplicate writes are innocuous
144+
# NOTE: this is not guaranteed to be visible in other kernels (JuliaGPU/CUDA.jl#2008)
142145
# XXX: what if this overflows? we can't increment ctr2. bump the key?
143146
rng.ctr1 += 1i32
144147

src/device/runtime.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@ end
2828

2929
struct KernelState
3030
exception_flag::Ptr{Cvoid}
31+
random_seed::UInt32
3132
end
3233

3334
@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)
3435

35-
exception_flag() = kernel_state().exception_flag
36-
3736
function signal_exception()
38-
ptr = exception_flag()
37+
ptr = kernel_state().exception_flag
3938
if ptr !== C_NULL
4039
unsafe_store!(convert(Ptr{Int}, ptr), 1)
4140
threadfence_system()

src/sorting.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,11 @@ end
459459

460460
#function sort
461461

462-
function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int, partial_k=nothing, block_size_shift=0) where {T,N,F1,F2}
463-
max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH)
462+
function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int, partial_k=nothing,
463+
block_size_shift=0) where {T,N,F1,F2}
464+
# XXX: after JuliaGPU/CUDA.jl#2035, which changed the kernel state struct contents,
465+
# the max depth needed to be reduced by 1 to avoid an illegal memory crash...
466+
max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH) - 1
464467
len = size(c, dims)
465468

466469
1 <= dims <= N || throw(ArgumentError("dimension out of range"))
@@ -884,11 +887,11 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
884887
# N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
885888
N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length
886889
pseudo_blocks_per_block = threads2 ÷ pseudo_block_length
887-
890+
888891
# grid dimensions
889892
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
890893
block_size = pseudo_block_length, threads2 ÷ pseudo_block_length
891-
894+
892895
kernel1(args1...; blocks=N_blocks, threads=block_size,
893896
shmem=bitonic_shmem(c, block_size))
894897
break

0 commit comments

Comments
 (0)