Skip to content

Commit ec8e84c

Browse files
committed
fix: rand! for AlphaSubGaussian
1 parent 45fe871 commit ec8e84c

File tree

2 files changed

+77
-62
lines changed

2 files changed

+77
-62
lines changed

src/AlphaStableDistributions.jl

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ Base.@kwdef struct AlphaStable{T} <: Distributions.ContinuousUnivariateDistribut
1414
location::T = zero(α)
1515
end
1616

17+
AlphaStable::Integer, β::Integer, scale::Integer, location::Integer) = AlphaStable(float(α), float(β), float(scale), float(location))
18+
1719

1820
# sampler(d::AlphaStable) = error("Not implemented")
1921
# pdf(d::AlphaStable, x::Real) = error("Not implemented")
@@ -267,31 +269,33 @@ This implementation is based on the method in J.M. Chambers, C.L. Mallows
267269
and B.W. Stuck, "A Method for Simulating Stable Random Variables," JASA 71 (1976): 340-4.
268270
McCulloch's MATLAB implementation (1996) served as a reference in developing this code.
269271
"""
270-
function Base.rand(rng::AbstractRNG, d::AlphaStable)
272+
function Base.rand(rng::AbstractRNG, d::AlphaStable{T}) where {T<:Real}
271273
α=d.α; β=d.β; scale=d.scale; loc=d.location
272274
< 0.1 || α > 2) && throw(DomainError(α, "α must be in the range 0.1 to 2"))
273275
abs(β) > 1 && throw(DomainError(β, "β must be in the range -1 to 1"))
274-
ϕ = (rand(rng) - 0.5) * π
275-
if α == 1 && β == 0
276-
return loc + scale*tan(ϕ)
276+
ϕ = (rand(rng, T) - 0.5) * π
277+
if α == one(T) && β == zero(T)
278+
return loc + scale * tan(ϕ)
277279
end
278-
w = -log(rand(rng))
280+
w = -log(rand(rng, T))
279281
α == 2 && (return loc + 2*scale*sqrt(w)*sin(ϕ))
280-
β == 0 && (return loc + scale * ((cos((1-α)*ϕ) / w)^(1.0/α - 1) * sin* ϕ) / cos(ϕ)^(1.0/α)))
282+
β == zero(T) && (return loc + scale * ((cos((1-α)*ϕ) / w)^(one(T)/α - one(T)) * sin* ϕ) / cos(ϕ)^(one(T)/α)))
281283
cosϕ = cos(ϕ)
282-
if abs-1) > 1e-8
283-
ζ = β * tan*α/2)
284+
if abs - one(T)) > 1e-8
285+
ζ = β * tan * α / 2)
284286
= α * ϕ
285-
a1ϕ = (1-α) * ϕ
286-
return loc + scale * (( (sin(aϕ)+ζ*cos(aϕ))/cosϕ * ((cos(a1ϕ)+ζ*sin(a1ϕ))) / ((w*cosϕ)^((1-α)/α)) ))
287+
a1ϕ = (one(T) - α) * ϕ
288+
return loc + scale * (( (sin(aϕ) + ζ * cos(aϕ))/cosϕ * ((cos(a1ϕ) + ζ*sin(a1ϕ))) / ((w*cosϕ)^((1-α)/α)) ))
287289
end
288290
= π/2 + β*ϕ
289-
x = 2/π * (bϕ*tan(ϕ) - β*log/2*w*cosϕ/bϕ))
290-
α == 1 || (x += β * tan*α/2))
291+
x = 2/π * (bϕ * tan(ϕ) - β * log/2*w*cosϕ/bϕ))
292+
α == one(T) || (x += β * tan*α/2))
291293

292-
return loc + scale*x
294+
return loc + scale * x
293295
end
294296

297+
Base.eltype(::Type{<:AlphaStable{T}}) where {T<:AbstractFloat} = T
298+
295299

296300
"""
297301
@@ -318,12 +322,16 @@ The maximum acceptable size of `R` is `10x10`
318322
julia> x = rand(AlphaSubGaussian(n=1000))
319323
```
320324
"""
321-
Base.@kwdef struct AlphaSubGaussian{T,M<:AbstractMatrix} <: Distributions.ContinuousUnivariateDistribution
325+
Base.@kwdef struct AlphaSubGaussian{T<:AbstractFloat} <: Distributions.ContinuousUnivariateDistribution
322326
α::T = 1.50
323-
R::M = SMatrix{5,5}(collect(SymmetricToeplitz([1.0000, 0.5804, 0.2140, 0.1444, -0.0135])))
327+
R::AbstractMatrix{T} = SMatrix{5,5}(collect(SymmetricToeplitz([1.0000, 0.5804, 0.2140, 0.1444, -0.0135])))
324328
n::Int
325329
end
326330

331+
AlphaSubGaussian::T, n::Int) where {T<:AbstractFloat} = AlphaSubGaussian=α,
332+
R=SMatrix{5,5}(T.(collect(SymmetricToeplitz([1.0000, 0.5804, 0.2140, 0.1444, -0.0135])))),
333+
n=n)
334+
327335
"""
328336
Generates the conditional probability f(X2|X1) if [X1, X2] is a sub-Gaussian
329337
stable random vector such that X1(i)~X2~S(alpha,delta) and rho is the correlation
@@ -366,10 +374,10 @@ function subgausscondprobtabulate(α, x1, x2_ind, invRx1, invR, vjoint, nmin, nm
366374
end
367375

368376

369-
function Random.rand!(rng::AbstractRNG, d::AlphaSubGaussian, x::AbstractArray)
377+
function Random.rand!(rng::AbstractRNG, d::AlphaSubGaussian{T}, x::AbstractArray{T}) where {T<:Real}
370378
α=d.α; R=d.R; n=d.n
371379
length(x) >= n || throw(ArgumentError("length of x must be at least n"))
372-
α 1.10:0.01:1.98 || throw(DomainError(α, "α must lie within `1.10:0.01:1.98`"))
380+
α T.(1.10:0.01:1.98) || throw(DomainError(α, "α must lie within `1.10:0.01:1.98`"))
373381
m = size(R, 1)-1
374382
funk1 = x -> (2^α)*sin*α/2)*gamma((α+2)/2)*gamma((α+x)/2)/(gamma(x/2)*π*α/2)
375383
funk2 = x -> 4*gamma(x/α)/((α*2^2)*gamma(x/2)^2)
@@ -388,11 +396,11 @@ function Random.rand!(rng::AbstractRNG, d::AlphaSubGaussian, x::AbstractArray)
388396
nmax, nmin, res, rind, vjoint = matdict["Nmax"]::Float64, matdict["Nmin"]::Float64, matdict["res"]::Float64, vec(matdict["rind"])::Vector{Float64}, matdict["vJoint"]::Matrix{Float64}
389397
step = (log10(nmax)-log10(nmin))/res
390398
m>size(vjoint, 1)-1 && throw(DomainError(R, "The dimensions of `R` exceed the maximum possible 10x10"))
391-
A = rand(AlphaStable/2, 1.0, 2*cos*α/4)^(2.0/α), 0.0))
392-
T = rand(Chisq(m))
399+
A = rand(AlphaStable(T(α/2), one(T), T(2*cos*α/4)^(2.0/α)), zero(T)))
400+
CT = rand(Chisq(m))
393401
S = randn(m)
394402
S = S/sqrt(sum(abs2,S))
395-
xtmp = ((sigrootx1*sqrt(A*T))*S)'
403+
xtmp = ((sigrootx1*sqrt(A*CT))*S)'
396404
if n<=m
397405
copyto!(x, @view(xtmp[1:n]))
398406
else
@@ -421,7 +429,8 @@ function Random.rand!(rng::AbstractRNG, d::AlphaSubGaussian, x::AbstractArray)
421429
end
422430

423431

424-
Base.rand(rng::AbstractRNG, d::AlphaSubGaussian) = rand!(rng, d, zeros(d.n))
432+
Base.rand(rng::AbstractRNG, d::AlphaSubGaussian) = rand!(rng, d, zeros(eltype(d), d.n))
433+
Base.eltype(::Type{<:AlphaSubGaussian{T}}) where {T} = T
425434

426435
"""
427436
fit(d::Type{<:AlphaSubGaussian}, x, m; p=1.0)

test/runtests.jl

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,57 +3,63 @@ using Test, Random, Distributions
33

44
@testset "AlphaStableDistributions.jl" begin
55

6+
sampletypes = [Float32,Float64]
67
stabletypes = [AlphaStable,SymmetricAlphaStable]
78
αs = [0.6:0.1:2,1:0.1:2]
8-
for (i, stabletype) in enumerate(stabletypes)
9-
for α in αs[i]
10-
d1 = AlphaStable=α)
11-
s = rand(d1, 100000)
9+
for sampletype sampletypes
10+
for (i, stabletype) in enumerate(stabletypes)
11+
for α in αs[i]
12+
d1 = AlphaStable=sampletype(α))
13+
s = rand(d1, 100000)
14+
@test eltype(s) == sampletype
1215

13-
d2 = fit(stabletype, s)
16+
d2 = fit(stabletype, s)
1417

15-
@test d1.α d2.α rtol=0.1
16-
stabletype != SymmetricAlphaStable && @test d1.β d2.β atol=0.2
17-
@test d1.scale d2.scale rtol=0.1
18-
@test d1.location d2.location atol=0.1
19-
end
18+
@test d1.α d2.α rtol=0.1
19+
stabletype != SymmetricAlphaStable && @test d1.β d2.β atol=0.2
20+
@test d1.scale d2.scale rtol=0.1
21+
@test d1.location d2.location atol=0.1
22+
end
2023

21-
xnormal = rand(Normal(3.0, 4.0), 96000)
22-
d = fit(stabletype, xnormal)
23-
@test d.α 2 rtol=0.2
24-
stabletype != SymmetricAlphaStable && @test d.β 0 atol=0.2
25-
@test d.scale 4/√2 rtol=0.2
26-
@test d.location 3 rtol=0.1
24+
xnormal = rand(Normal(3.0, 4.0), 96000)
25+
d = fit(stabletype, xnormal)
26+
@test d.α 2 rtol=0.2
27+
stabletype != SymmetricAlphaStable && @test d.β 0 atol=0.2
28+
@test d.scale 4/√2 rtol=0.2
29+
@test d.location 3 rtol=0.1
2730

28-
xcauchy = rand(Cauchy(3.0, 4.0), 96000)
29-
d = fit(stabletype, xcauchy)
30-
@test d.α 1 rtol=0.2
31-
stabletype != SymmetricAlphaStable && @test d.β 0 atol=0.2
32-
@test d.scale 4 rtol=0.2
33-
@test d.location 3 rtol=0.1
34-
end
31+
xcauchy = rand(Cauchy(3.0, 4.0), 96000)
32+
d = fit(stabletype, xcauchy)
33+
@test d.α 1 rtol=0.2
34+
stabletype != SymmetricAlphaStable && @test d.β 0 atol=0.2
35+
@test d.scale 4 rtol=0.2
36+
@test d.location 3 rtol=0.1
37+
end
3538

36-
for α in 1.1:0.1:1.9
37-
d = AlphaSubGaussian(n=96000, α=α)
38-
x = rand(d)
39-
x2 = copy(x)
40-
rand!(d, x2)
41-
@test x != x2
39+
for α in 1.1:0.1:1.9
40+
d = AlphaSubGaussian(sampletype(α), 96000)
41+
x = rand(d)
42+
@test eltype(x) == sampletype
43+
x2 = copy(x)
44+
rand!(d, x2)
45+
@test x != x2
4246

43-
d3 = fit(AlphaStable, x)
44-
@test d3.α α rtol=0.2
45-
@test d3.β 0 atol=0.2
46-
@test d3.scale 1 rtol=0.2
47-
@test d3.location 0 atol=0.1
48-
end
47+
d3 = fit(AlphaStable, x)
48+
@test d3.α α rtol=0.2
49+
@test d3.β 0 atol=0.2
50+
@test d3.scale 1 rtol=0.2
51+
@test d3.location 0 atol=0.1
52+
end
4953

50-
d4 = AlphaSubGaussian(n=96000)
51-
m = size(d4.R, 1)-1
52-
x = rand(d4)
53-
d5 = fit(AlphaSubGaussian, x, m, p=1.0)
54-
@test d4.α d5.α rtol=0.1
55-
@test d4.R d5.R rtol=0.1
54+
d4 = AlphaSubGaussian(sampletype(1.5), 96000)
55+
m = size(d4.R, 1) - 1
56+
x = rand(d4)
57+
@test eltype(x) == sampletype
58+
d5 = fit(AlphaSubGaussian, x, m, p=1.0)
59+
@test d4.α d5.α rtol=0.1
60+
@test d4.R d5.R rtol=0.1
5661

62+
end
5763
end
5864
# 362.499 ms (4620903 allocations: 227.64 MiB)
5965
# 346.520 ms (4621052 allocations: 209.62 MiB) # StaticArrays in outer fun

0 commit comments

Comments
 (0)