134134
135135 # Poisson sampling
136136 @inbounds for k in 1 : num_jumps
137- counts[k] = pois_rand (rate_cache[k])
137+ counts[k] = pois_rand (PassthroughRNG (), rate_cache[k])
138138 end
139139
140140 # Apply changes
@@ -214,12 +214,17 @@ end
214214# GPU-compatible Poisson sampling PassthroughRNG
215215struct PassthroughRNG <: AbstractRNG end
216216
217+ rand (rng:: PassthroughRNG ) = Random. rand ()
218+ randexp (rng:: PassthroughRNG ) = Random. randexp ()
219+ randn (rng:: PassthroughRNG ) = Random. randn ()
220+
221+ count_rand (λ) = count_rand (Random. GLOBAL_RNG, λ)
217222function count_rand (rng:: AbstractRNG , λ)
218223 n = 0
219- c = rng isa PassthroughRNG ? randexp () : randexp (rng)
224+ c = randexp (rng)
220225 while c < λ
221226 n += 1
222- c += rng isa PassthroughRNG ? randexp () : randexp (rng)
227+ c += randexp (rng)
223228 end
224229 return n
225230end
@@ -232,43 +237,50 @@ end
232237#
233238# For μ sufficiently large, (i.e. >= 10.0)
234239#
240+ ad_rand (λ) = ad_rand (Random. GLOBAL_RNG, λ)
235241function ad_rand (rng:: AbstractRNG , λ)
236- λ = Float64 (λ)
237242 s = sqrt (λ)
238- d = 6.0 * λ^ 2
243+ d = 6 * λ^ 2
239244 L = floor (Int, λ - 1.1484 )
240-
241- G = λ + s * (rng isa PassthroughRNG ? randn () : randn ( rng) )
245+ # Step N
246+ G = λ + s * randn (rng)
242247
243248 if G >= 0
244249 K = floor (Int, G)
250+ # Step I
245251 if K >= L
246252 return K
247253 end
248254
249- U = rng isa PassthroughRNG ? rand () : rand (rng)
255+ # Step S
256+ U = rand (rng)
250257 if d * U >= (λ - K)^ 3
251258 return K
252259 end
253260
261+ # Step P
254262 px, py, fx, fy = procf (λ, K, s)
263+
264+ # Step Q
255265 if fy * (1 - U) <= py * exp (px - fx)
256266 return K
257267 end
258268 end
259269
260270 while true
261- E = rng isa PassthroughRNG ? randexp () : randexp (rng)
262- U = 2 * (rng isa PassthroughRNG ? rand () : rand (rng)) - 1
263- T_val = 1.8 + copysign (E, U)
264- if T_val <= - 0.6744
271+ # Step E
272+ E = randexp (rng)
273+ U = 2 * rand (rng) - 1
274+ T = 1.8 + copysign (E, U)
275+ if T <= - 0.6744
265276 continue
266277 end
267278
268- K = floor (Int, λ + s * T_val )
279+ K = floor (Int, λ + s * T )
269280 px, py, fx, fy = procf (λ, K, s)
270281 c = 0.1069 / λ
271282
283+ # Step H
272284 @fastmath if c * abs (U) <= py * exp (px + E) - fy * exp (fx + E)
273285 return K
274286 end
@@ -277,30 +289,30 @@ end
277289
278290# Procedure F
279291function procf (λ, K:: Int , s:: Float64 )
280- INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π)
292+ # can be pre-computed, but does not seem to affect performance
293+ INV_SQRT_2PI = inv (sqrt (2pi ))
281294 ω = INV_SQRT_2PI / s
282- b1 = 1 / (24 * λ)
283- b2 = 0.3 * b1^ 2
284- c3 = b1 * b2 / 7
295+ b1 = inv (24 ) / λ
296+ b2 = 0.3 * b1 * b1
297+ c3 = inv ( 7 ) * b1 * b2
285298 c2 = b2 - 15 * c3
286299 c1 = b1 - 6 * b2 + 45 * c3
287300 c0 = 1 - b1 + 3 * b2 - 15 * c3
288301
289302 if K < 10
290303 px = - λ
291- log_py = K * log (λ) - loggamma (K + 1 ) # log(K!) via loggamma
304+ log_py = K * log (λ) - loggamma (K + 1 ) # log(K!) via loggamma
292305 py = exp (log_py)
293306 else
294- δ = 1 / (12 * K)
307+ δ = inv (12 ) / K
295308 δ -= 4.8 * δ^ 3
296309 V = (λ - K) / K
297- px = K * log1pmx (V) - δ
310+ px = K * log1pmx (V) - δ # avoids need for table
298311 py = INV_SQRT_2PI / sqrt (K)
299312 end
300-
301313 X = (K - λ + 0.5 ) / s
302314 X2 = X^ 2
303- fx = - X2 / 2
315+ fx = X2 / - 2 # missing negation in pseudo-algorithm, but appears in fortran code.
304316 fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
305317 return px, py, fx, fy
306318end
@@ -316,16 +328,19 @@ Generates Poisson(λ) distributed random numbers using a fast polyalgorithm.
316328## Examples
317329
318330```julia
319- # Simple Poisson random which works on GPU
331+ # Simple Poisson random
320332pois_rand(λ)
321333
322- # Using RNG
334+ # Using another RNG
323335using RandomNumbers
324336rng = Xorshifts.Xoroshiro128Plus()
325337pois_rand(rng, λ)
338+
339+ # Simple Poisson random on GPU
340+ pois_rand(PassthroughRNG(), λ)
326341```
327342"""
328- pois_rand (λ) = λ < 6 ? count_rand ( PassthroughRNG (), λ) : ad_rand ( PassthroughRNG () , λ)
343+ pois_rand (λ) = pois_rand (Random . GLOBAL_RNG , λ)
329344pois_rand (rng:: AbstractRNG , λ) = λ < 6 ? count_rand (rng, λ) : ad_rand (rng, λ)
330345
331346end
0 commit comments