Skip to content

Commit bf62c27

Browse files
replaced with pois_rand with PassthroughRNG
1 parent 42c2257 commit bf62c27

File tree

3 files changed

+3
-141
lines changed

3 files changed

+3
-141
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1212
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15-
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1615
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
1716
PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
1817
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1918
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2019
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2120
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2221
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
23-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2422
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2523
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2624
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

ext/JumpProcessesKernelAbstractionsExt.jl

Lines changed: 2 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ module JumpProcessesKernelAbstractionsExt
33
using JumpProcesses, SciMLBase
44
using KernelAbstractions, Adapt
55
using StaticArrays
6-
using Random
7-
using LogExpFunctions: log1pmx
8-
using SpecialFunctions: loggamma
6+
using PoissonRandom, Random
97

108
function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem,
119
alg::SimpleTauLeaping,
@@ -134,7 +132,7 @@ end
134132

135133
# Poisson sampling
136134
@inbounds for k in 1:num_jumps
137-
counts[k] = pois_rand(PassthroughRNG(), rate_cache[k])
135+
counts[k] = pois_rand(PoissonRandom.PassthroughRNG(), rate_cache[k])
138136
end
139137

140138
# Apply changes
@@ -211,138 +209,4 @@ function vectorized_solve(probs, prob::JumpProblem, alg::SimpleTauLeaping;
211209
return ts, us
212210
end
213211

214-
export pois_rand, PassthroughRNG
215-
216-
# GPU-compatible Poisson sampling via PassthroughRNG
217-
struct PassthroughRNG <: AbstractRNG end
218-
219-
rand(rng::PassthroughRNG) = Random.rand()
220-
randexp(rng::PassthroughRNG) = Random.randexp()
221-
randn(rng::PassthroughRNG) = Random.randn()
222-
223-
count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ)
224-
function count_rand(rng::AbstractRNG, λ)
225-
n = 0
226-
c = randexp(rng)
227-
while c < λ
228-
n += 1
229-
c += randexp(rng)
230-
end
231-
return n
232-
end
233-
234-
# Algorithm from:
235-
#
236-
# J.H. Ahrens, U. Dieter (1982)
237-
# "Computer Generation of Poisson Deviates from Modified Normal Distributions"
238-
# ACM Transactions on Mathematical Software, 8(2):163-179
239-
#
240-
# For μ sufficiently large, (i.e. >= 10.0)
241-
#
242-
ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
243-
function ad_rand(rng::AbstractRNG, λ)
244-
s = sqrt(λ)
245-
d = 6 * λ^2
246-
L = floor(Int, λ - 1.1484)
247-
# Step N
248-
G = λ + s * randn(rng)
249-
250-
if G >= 0
251-
K = floor(Int, G)
252-
# Step I
253-
if K >= L
254-
return K
255-
end
256-
257-
# Step S
258-
U = rand(rng)
259-
if d * U >=- K)^3
260-
return K
261-
end
262-
263-
# Step P
264-
px, py, fx, fy = procf(λ, K, s)
265-
266-
# Step Q
267-
if fy * (1 - U) <= py * exp(px - fx)
268-
return K
269-
end
270-
end
271-
272-
while true
273-
# Step E
274-
E = randexp(rng)
275-
U = 2 * rand(rng) - 1
276-
T = 1.8 + copysign(E, U)
277-
if T <= -0.6744
278-
continue
279-
end
280-
281-
K = floor(Int, λ + s * T)
282-
px, py, fx, fy = procf(λ, K, s)
283-
c = 0.1069 / λ
284-
285-
# Step H
286-
@fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E)
287-
return K
288-
end
289-
end
290-
end
291-
292-
# Procedure F
293-
function procf(λ, K::Int, s::Float64)
294-
# can be pre-computed, but does not seem to affect performance
295-
INV_SQRT_2PI = inv(sqrt(2pi))
296-
ω = INV_SQRT_2PI / s
297-
b1 = inv(24) / λ
298-
b2 = 0.3 * b1 * b1
299-
c3 = inv(7) * b1 * b2
300-
c2 = b2 - 15 * c3
301-
c1 = b1 - 6 * b2 + 45 * c3
302-
c0 = 1 - b1 + 3 * b2 - 15 * c3
303-
304-
if K < 10
305-
px = -float(λ)
306-
log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma
307-
py = exp(log_py)
308-
else
309-
δ = inv(12) / K
310-
δ -= 4.8 * δ^3
311-
V =- K) / K
312-
px = K * log1pmx(V) - δ # avoids need for table
313-
py = INV_SQRT_2PI / sqrt(K)
314-
end
315-
X = (K - λ + 0.5) / s
316-
X2 = X^2
317-
fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code.
318-
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
319-
return px, py, fx, fy
320-
end
321-
322-
"""
323-
```julia
324-
pois_rand(λ)
325-
pois_rand(rng::AbstractRNG, λ)
326-
```
327-
328-
Generates Poisson(λ) distributed random numbers using a fast polyalgorithm.
329-
330-
## Examples
331-
332-
```julia
333-
# Simple Poisson random
334-
pois_rand(λ)
335-
336-
# Using another RNG
337-
using RandomNumbers
338-
rng = Xorshifts.Xoroshiro128Plus()
339-
pois_rand(rng, λ)
340-
341-
# Simple Poisson random on GPU
342-
pois_rand(PoissonRandom.PassthroughRNG(), λ)
343-
```
344-
"""
345-
pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)
346-
pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ)
347-
348212
end

test/gpu/regular_jumps.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using KernelAbstractions, Adapt, CUDA
44
using StableRNGs
55
rng = StableRNG(12345)
66

7-
Nsims = 100_000
7+
Nsims = 100
88

99
# SIR model with influx
1010
let

0 commit comments

Comments
 (0)