|
| 1 | +## random number generation |
| 2 | + |
| 3 | +using Random |
| 4 | +import RandomNumbers |
| 5 | + |
| 6 | + |
| 7 | +# global state |
| 8 | + |
# Return an `LLVMPtr{UInt32,AS.ThreadGroup}` to the first element of a 32-element `UInt32`
# array stored in an LLVM global variable called "global_random_$(name)". The `Val`
# parameter only selects the name of that global, so each distinct `name` gets its own
# variable.
@inline @generated function emit_global_random_values(::Val{name}) where name
    @dispose ctx=Context() begin
        elem_ty = convert(LLVMType, UInt32)
        ret_ty = convert(LLVMType, LLVMPtr{UInt32,AS.ThreadGroup})

        # wrapper function that returns the pointer, and the module that owns it
        llvm_f, _ = create_function(ret_ty)
        mod = LLVM.parent(llvm_f)

        # zero-initialized array global placed in the threadgroup address space
        array_ty = LLVM.ArrayType(elem_ty, 32)
        gv = GlobalVariable(mod, array_ty, "global_random_$(name)", AS.ThreadGroup)
        linkage!(gv, LLVM.API.LLVMLinkOnceAnyLinkage)
        initializer!(gv, LLVM.null(array_ty))
        unnamed_addr!(gv, true)
        alignment!(gv, 4)

        # function body: take the address of element 0 and return it as `ret_ty`
        @dispose builder=IRBuilder() begin
            position!(builder, BasicBlock(llvm_f, "entry"))

            first_elem = gep!(builder, array_ty, gv, [ConstantInt(0), ConstantInt(0)])
            ret!(builder, bitcast!(builder, first_elem, ret_ty))
        end

        call_function(llvm_f, LLVMPtr{UInt32,AS.ThreadGroup})
    end
end
| 41 | + |
# threadgroup memory holding the actual seed (the Philox key), one `UInt32` slot per warp;
# written by `initialize_rng_state` and overridden by calling `seed!`
@inline function global_random_keys()
    base = emit_global_random_values(Val(:keys))::LLVMPtr{UInt32,AS.ThreadGroup}
    return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((32,), base)
end
| 47 | + |
# threadgroup memory holding the per-warp counters (one `UInt32` slot per warp), which are
# incremented every time a number is generated
@inline function global_random_counters()
    base = emit_global_random_values(Val(:counters))::LLVMPtr{UInt32,AS.ThreadGroup}
    return MtlDeviceArray{UInt32,1,AS.ThreadGroup}((32,), base)
end
| 53 | + |
# initialization function, called automatically at the start of each kernel because
# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
#
# Every thread stores the kernel's random seed into its warp's key slot and resets its
# warp's counter slot to zero.
function initialize_rng_state()
    tid = thread_position_in_threadgroup_1d()
    warp = ((tid - Int32(1)) >> 0x5) + Int32(1)  # fld1 by 32

    key_slots = global_random_keys()
    counter_slots = global_random_counters()

    @inbounds key_slots[warp] = kernel_state().random_seed
    @inbounds counter_slots[warp] = 0
end
| 63 | + |
| 64 | +# generators |
| 65 | + |
| 66 | +using Random123: philox2x_round, philox2x_bumpkey |
| 67 | + |
# GPU-compatible/optimized version of the generator from Random123.jl
#
# `R` is the number of Philox rounds. The struct itself carries no fields.
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
    @inline Philox2x32{R}() where {R} = new{R}()
end
| 74 | + |
# default to 7 rounds; enough to pass SmallCrush
@inline function Philox2x32()
    return Philox2x32{7}()
end
| 77 | + |
# Virtual fields of the (field-less) generator, all returned as `UInt32`:
#   :seed -- the random seed carried in the kernel state
#   :key  -- the calling thread's per-warp key slot
#   :ctr1 -- the calling thread's per-warp counter slot
#   :ctr2 -- the calling thread's 1D position in the grid, truncated to 32 bits
# Any other field name makes the `if` evaluate to `nothing`, so the `::UInt32` assertion
# at the end fails.
@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
    threadId = thread_position_in_threadgroup_1d()
    warpId = (threadId - Int32(1)) >> 0x5 + Int32(1)  # fld1 by 32

    if field === :seed
        # read the seed from the kernel state, the same source `initialize_rng_state`
        # uses; no `global_random_seed` helper is defined alongside the key/counter ones
        kernel_state().random_seed % UInt32
    elseif field === :key
        @inbounds global_random_keys()[warpId]
    elseif field === :ctr1
        @inbounds global_random_counters()[warpId]
    elseif field === :ctr2
        globalId = thread_position_in_grid_1d()
        globalId % UInt32
    end::UInt32
end
| 93 | + |
# Store `x` into the calling thread's per-warp slot: `:key` writes the key array and
# `:ctr1` writes the counter array. Assignments to any other field are silently ignored.
@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
    tid = thread_position_in_threadgroup_1d()
    warp = ((tid - Int32(1)) >> 0x5) + Int32(1)  # fld1 by 32

    if field === :ctr1
        counter_slots = global_random_counters()
        @inbounds counter_slots[warp] = x
    elseif field === :key
        key_slots = global_random_keys()
        @inbounds key_slots[warp] = x
    end
end
| 104 | + |
# on device, the default RNG is the Philox2x32 generator with its default round count
@device_override @inline function Random.default_rng()
    return Philox2x32()
end
| 106 | + |
"""
    Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])

Seed the on-device Philox2x32 generator with a `UInt32` number, and reset its per-warp
counter to `counter`. Both values are truncated to 32 bits.

Should be called by at least one thread per warp.
"""
function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=UInt32(0))
    rng.key = seed % UInt32
    # truncate like the seed, so that negative or oversized counters wrap around instead
    # of hitting a throwing conversion when stored into the `UInt32` counter slot
    rng.ctr1 = counter % UInt32
    return
end
| 118 | + |
# seeding the implicit default RNG: which method has to be overridden depends on the
# Julia version, but both forward to the device-side default generator
@static if VERSION < v"1.11-"
    @device_override function Random.seed!(::Random._GLOBAL_RNG, seed)
        return Random.seed!(Random.default_rng(), seed)
    end
else
    @device_override function Random.seed!(seed)
        return Random.seed!(Random.default_rng(), seed)
    end
end
| 127 | + |
"""
    Random.rand(rng::Philox2x32, UInt64)

Generate 64 bits of random data using the on-device Philox2x32 generator.

The per-warp counter (`rng.ctr1`), the per-thread value `rng.ctr2` and the per-warp key
(`rng.key`) are mixed through `R` Philox rounds, bumping the key between rounds. At most
16 rounds are applied, whatever the value of `R`. Afterwards the per-warp counter is
incremented, and the two resulting 32-bit words are returned packed into one `UInt64`.
"""
function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
    ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key

    # the first round uses the key as-is
    if R > 0
        ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
    end

    # rounds 2 to 16, each preceded by a key bump. `@nexprs` expands at parse time into
    # 15 separate `if R > i` statements, so the rounds stay fully unrolled.
    Base.Cartesian.@nexprs 15 i -> begin
        if R > i
            key = philox2x_bumpkey(key)
            ctr1, ctr2 = philox2x_round(ctr1, ctr2, key)
        end
    end

    # update the warp counter
    # NOTE: this performs the same update on every thread in the warp, but each warp writes
    #       to a unique location so the duplicate writes are innocuous
    # XXX: what if this overflows? we can't increment ctr2. bump the key?
    rng.ctr1 += Int32(1)

    # NOTE: it's too expensive to keep both numbers around in case the user only wanted one,
    #       so just make our 2x32 generator return 64-bit numbers by default.
    return ((ctr1 % UInt64) << 32) | (ctr2 % UInt64)
end
| 163 | + |
| 164 | + |
| 165 | +# normally distributed |
| 166 | + |
# Dispatch explicitly to Base's generic `AbstractFloat` method, which doesn't widen and
# only relies on `rand()`. The Ziggurat method used by other back-ends relies on Float64
# support.
@device_override @inline function Random.randn(rng::Philox2x32, ::Type{T}) where {T <: AbstractFloat}
    return invoke(Random.randn, Tuple{AbstractRNG,Type{<:AbstractFloat}}, rng, T)
end
| 172 | + |
| 173 | + |
| 174 | +# exponentially distributed |
| 175 | + |
# Dispatch explicitly to Base's generic `AbstractFloat` method, which doesn't widen and
# only relies on `rand()`. The Ziggurat method used by other back-ends relies on Float64
# support.
@device_override @inline function Random.randexp(rng::Philox2x32, ::Type{T}) where {T <: AbstractFloat}
    return invoke(Random.randexp, Tuple{AbstractRNG,Type{<:AbstractFloat}}, rng, T)
end
0 commit comments