Skip to content

Commit b3c76d6

Browse files
pois_rand optimization
1 parent 8a41d58 commit b3c76d6

File tree

1 file changed

+40
-25
lines changed

1 file changed

+40
-25
lines changed

ext/JumpProcessesKernelAbstractionsExt.jl

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ end
134134

135135
# Poisson sampling
136136
@inbounds for k in 1:num_jumps
137-
counts[k] = pois_rand(rate_cache[k])
137+
counts[k] = pois_rand(PassthroughRNG(), rate_cache[k])
138138
end
139139

140140
# Apply changes
@@ -214,12 +214,17 @@ end
214214
# GPU-compatible Poisson sampling PassthroughRNG
215215
struct PassthroughRNG <: AbstractRNG end
216216

217+
rand(rng::PassthroughRNG) = Random.rand()
218+
randexp(rng::PassthroughRNG) = Random.randexp()
219+
randn(rng::PassthroughRNG) = Random.randn()
220+
221+
count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ)
217222
function count_rand(rng::AbstractRNG, λ)
218223
n = 0
219-
c = rng isa PassthroughRNG ? randexp() : randexp(rng)
224+
c = randexp(rng)
220225
while c < λ
221226
n += 1
222-
c += rng isa PassthroughRNG ? randexp() : randexp(rng)
227+
c += randexp(rng)
223228
end
224229
return n
225230
end
@@ -232,43 +237,50 @@ end
232237
#
233238
# For μ sufficiently large, (i.e. >= 10.0)
234239
#
240+
ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
235241
function ad_rand(rng::AbstractRNG, λ)
236-
λ = Float64(λ)
237242
s = sqrt(λ)
238-
d = 6.0 * λ^2
243+
d = 6 * λ^2
239244
L = floor(Int, λ - 1.1484)
240-
241-
G = λ + s * (rng isa PassthroughRNG ? randn() : randn(rng))
245+
# Step N
246+
G = λ + s * randn(rng)
242247

243248
if G >= 0
244249
K = floor(Int, G)
250+
# Step I
245251
if K >= L
246252
return K
247253
end
248254

249-
U = rng isa PassthroughRNG ? rand() : rand(rng)
255+
# Step S
256+
U = rand(rng)
250257
if d * U >=- K)^3
251258
return K
252259
end
253260

261+
# Step P
254262
px, py, fx, fy = procf(λ, K, s)
263+
264+
# Step Q
255265
if fy * (1 - U) <= py * exp(px - fx)
256266
return K
257267
end
258268
end
259269

260270
while true
261-
E = rng isa PassthroughRNG ? randexp() : randexp(rng)
262-
U = 2 * (rng isa PassthroughRNG ? rand() : rand(rng)) - 1
263-
T_val = 1.8 + copysign(E, U)
264-
if T_val <= -0.6744
271+
# Step E
272+
E = randexp(rng)
273+
U = 2 * rand(rng) - 1
274+
T = 1.8 + copysign(E, U)
275+
if T <= -0.6744
265276
continue
266277
end
267278

268-
K = floor(Int, λ + s * T_val)
279+
K = floor(Int, λ + s * T)
269280
px, py, fx, fy = procf(λ, K, s)
270281
c = 0.1069 / λ
271282

283+
# Step H
272284
@fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E)
273285
return K
274286
end
@@ -277,30 +289,30 @@ end
277289

278290
# Procedure F
279291
function procf(λ, K::Int, s::Float64)
280-
INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π)
292+
# can be pre-computed, but does not seem to affect performance
293+
INV_SQRT_2PI = inv(sqrt(2pi))
281294
ω = INV_SQRT_2PI / s
282-
b1 = 1 / (24 * λ)
283-
b2 = 0.3 * b1^2
284-
c3 = b1 * b2 / 7
295+
b1 = inv(24) / λ
296+
b2 = 0.3 * b1 * b1
297+
c3 = inv(7) * b1 * b2
285298
c2 = b2 - 15 * c3
286299
c1 = b1 - 6 * b2 + 45 * c3
287300
c0 = 1 - b1 + 3 * b2 - 15 * c3
288301

289302
if K < 10
290303
px = -λ
291-
log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma
304+
log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma
292305
py = exp(log_py)
293306
else
294-
δ = 1 / (12 * K)
307+
δ = inv(12) / K
295308
δ -= 4.8 * δ^3
296309
V =- K) / K
297-
px = K * log1pmx(V) - δ
310+
px = K * log1pmx(V) - δ # avoids need for table
298311
py = INV_SQRT_2PI / sqrt(K)
299312
end
300-
301313
X = (K - λ + 0.5) / s
302314
X2 = X^2
303-
fx = -X2 / 2
315+
fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code.
304316
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
305317
return px, py, fx, fy
306318
end
@@ -316,16 +328,19 @@ Generates Poisson(λ) distributed random numbers using a fast polyalgorithm.
316328
## Examples
317329
318330
```julia
319-
# Simple Poisson random which works on GPU
331+
# Simple Poisson random
320332
pois_rand(λ)
321333
322-
# Using RNG
334+
# Using another RNG
323335
using RandomNumbers
324336
rng = Xorshifts.Xoroshiro128Plus()
325337
pois_rand(rng, λ)
338+
339+
# Simple Poisson random on GPU
340+
pois_rand(PassthroughRNG(), λ)
326341
```
327342
"""
328-
pois_rand(λ) = λ < 6 ? count_rand(PassthroughRNG(), λ) : ad_rand(PassthroughRNG(), λ)
343+
pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)
329344
pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ)
330345

331346
end

0 commit comments

Comments
 (0)