Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ uuid = "e409e4f3-bfea-5376-8464-e040bb5c01ab"
version = "0.4.4"

[deps]
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Aqua = "0.8"
Distributions = "0.25"
LogExpFunctions = "0.3"
Random = "1.10"
Statistics = "1"
Test = "1"
Expand Down
66 changes: 14 additions & 52 deletions src/PoissonRandom.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module PoissonRandom

using Random
using LogExpFunctions: log1pmx

export pois_rand

Expand All @@ -26,12 +27,12 @@ end
ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
function ad_rand(rng::AbstractRNG, λ)
s = sqrt(λ)
d = 6.0 * λ^2
d = 6 * λ^2
L = floor(Int, λ - 1.1484)
# Step N
G = λ + s * randn(rng)

if G >= 0.0
if G >= 0
K = floor(Int, G)
# Step I
if K >= L
Expand All @@ -56,7 +57,7 @@ function ad_rand(rng::AbstractRNG, λ)
while true
# Step E
E = randexp(rng)
U = 2.0 * rand(rng) - 1.0
U = 2 * rand(rng) - 1
T = 1.8 + copysign(E, U)
if T <= -0.6744
continue
Expand All @@ -73,70 +74,31 @@ function ad_rand(rng::AbstractRNG, λ)
end
end

# log(1+x)-x
# accurate ~2ulps for -0.227 < x < 0.315
function log1pmx_kernel(x::Float64)
r = x / (x + 2.0)
t = r * r
w = @evalpoly(t,
6.66666666666666667e-1, # 2/3
4.00000000000000000e-1, # 2/5
2.85714285714285714e-1, # 2/7
2.22222222222222222e-1, # 2/9
1.81818181818181818e-1, # 2/11
1.53846153846153846e-1, # 2/13
1.33333333333333333e-1, # 2/15
1.17647058823529412e-1) # 2/17
hxsq = 0.5 * x * x
r * (hxsq + w * t) - hxsq
end

# use naive calculation or range reduction outside kernel range.
# accurate ~2ulps for all x
function log1pmx(x::Float64)
if !(-0.7 < x < 0.9)
return log1p(x) - x
elseif x > 0.315
u = (x - 0.5) / 1.5
return log1pmx_kernel(u) - 9.45348918918356180e-2 - 0.5 * u
elseif x > -0.227
return log1pmx_kernel(x)
elseif x > -0.4
u = (x + 0.25) / 0.75
return log1pmx_kernel(u) - 3.76820724517809274e-2 + 0.25 * u
elseif x > -0.6
u = (x + 0.5) * 2.0
return log1pmx_kernel(u) - 1.93147180559945309e-1 + 0.5 * u
else
u = (x + 0.625) / 0.375
return log1pmx_kernel(u) - 3.55829253011726237e-1 + 0.625 * u
end
end

# Procedure F
function procf(λ, K::Int, s::Float64)
# can be pre-computed, but does not seem to affect performance
ω = 0.3989422804014327 / s
b1 = 0.041666666666666664 / λ
INV_SQRT_2PI = inv(sqrt(2pi))
ω = INV_SQRT_2PI / s
b1 = inv(24) / λ
b2 = 0.3 * b1 * b1
c3 = 0.14285714285714285 * b1 * b2
c2 = b2 - 15.0 * c3
c1 = b1 - 6.0 * b2 + 45.0 * c3
c0 = 1.0 - b1 + 3.0 * b2 - 15.0 * c3
c3 = inv(7) * b1 * b2
c2 = b2 - 15 * c3
c1 = b1 - 6 * b2 + 45 * c3
c0 = 1 - b1 + 3 * b2 - 15 * c3

if K < 10
px = -float(λ)
py = λ^K / factorial(K)
else
δ = 0.08333333333333333 / K
δ = inv(12) / K
δ -= 4.8 * δ^3
V = (λ - K) / K
px = K * log1pmx(V) - δ # avoids need for table
py = 0.3989422804014327 / sqrt(K)
py = INV_SQRT_2PI / sqrt(K)
end
X = (K - λ + 0.5) / s
X2 = X^2
fx = -0.5 * X2 # missing negation in pseudo-algorithm, but appears in fortran code.
fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code.
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
return px, py, fx, fy
end
Expand Down
Loading