Skip to content

Commit 747a191

Browse files
committed
Decouple rand and eltype
1 parent a1010e4 commit 747a191

19 files changed

+166
-36
lines changed

src/genericrand.jl

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,23 @@ function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate})
3030
end
3131

3232
# multiple samples
33-
function rand(rng::AbstractRNG, s::Sampleable{Univariate}, dims::Dims)
34-
out = Array{eltype(s)}(undef, dims)
35-
return @inbounds rand!(rng, sampler(s), out)
33+
# we use function barriers since for some distributions `sampler(s)` is not type-stable:
34+
# https://github.com/JuliaStats/Distributions.jl/pull/1281
35+
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
36+
return _rand(rng, sampler(s), dims)
3637
end
37-
function rand(
38-
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims,
39-
)
40-
sz = size(s)
41-
ax = map(Base.OneTo, dims)
42-
out = [Array{eltype(s)}(undef, sz) for _ in Iterators.product(ax...)]
43-
return @inbounds rand!(rng, sampler(s), out, false)
38+
function _rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate}, dims::Dims)
39+
r = rand(rng, s)
40+
out = Array{typeof(r)}(undef, dims)
41+
out[begin] = r
42+
rand!(rng, s, @view(out[2:end]))
43+
return out
4444
end
4545

46-
# these are workarounds for sampleables that incorrectly base `eltype` on the parameters
46+
# this is a workaround for sampleables that incorrectly base `eltype` on the parameters
4747
function rand(rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous})
4848
return @inbounds rand!(rng, sampler(s), Array{float(eltype(s))}(undef, size(s)))
4949
end
50-
function rand(rng::AbstractRNG, s::Sampleable{Univariate,Continuous}, dims::Dims)
51-
out = Array{float(eltype(s))}(undef, dims)
52-
return @inbounds rand!(rng, sampler(s), out)
53-
end
54-
function rand(
55-
rng::AbstractRNG, s::Sampleable{<:ArrayLikeVariate,Continuous}, dims::Dims,
56-
)
57-
sz = size(s)
58-
ax = map(Base.OneTo, dims)
59-
out = [Array{float(eltype(s))}(undef, sz) for _ in Iterators.product(ax...)]
60-
return @inbounds rand!(rng, sampler(s), out, false)
61-
end
6250

6351
"""
6452
rand!([rng::AbstractRNG,] s::Sampleable, A::AbstractArray)

src/multivariate/dirichlet.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ end
154154

155155
# sampling
156156

157+
function rand(rng::AbstractRNG, d::Union{Dirichlet,DirichletCanon})
158+
x = map(αi -> rand(rng, Gamma(αi)), d.alpha)
159+
return lmul!(inv(sum(x)), x)
160+
end
161+
function rand(rng::AbstractRNG, d::Dirichlet{<:Real,<:FillArrays.AbstractFill{<:Real}})
162+
x = rand(rng, Gamma(FillArrays.getindex_value(d.alpha)), length(d))
163+
return lmul!(inv(sum(x)), x)
164+
end
165+
157166
function _rand!(rng::AbstractRNG,
158167
d::Union{Dirichlet,DirichletCanon},
159168
x::AbstractVector{<:Real})

src/multivariate/dirichletmultinomial.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ end
9797

9898

9999
# Sampling
100+
rand(rng::AbstractRNG, d::DirichletMultinomial) =
101+
multinom_rand(rng, ntrials(d), rand(rng, Dirichlet(d.α)))
100102
_rand!(rng::AbstractRNG, d::DirichletMultinomial, x::AbstractVector{<:Real}) =
101103
multinom_rand!(rng, ntrials(d), rand(rng, Dirichlet(d.α)), x)
102104

src/multivariate/jointorderstatistics.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,27 @@ function _marginalize_range(dist, i, j, xᵢ, xⱼ, T)
125125
return k * T(logdiffcdf(dist, xⱼ, xᵢ)) - loggamma(T(k + 1))
126126
end
127127

128+
function rand(rng::AbstractRNG, d::JointOrderStatistics)
129+
n = d.n
130+
if n == length(d.ranks) # ranks == 1:n
131+
# direct method, slower than inversion method for large `n` and distributions with
132+
# fast quantile function or that use inversion sampling
133+
x = rand(rng, d.dist, n)
134+
sort!(x)
135+
else
136+
# use exponential generation method with inversion, where for gaps in the ranks, we
137+
# use the fact that the sum Y of k IID variables xₘ ~ Exp(1) is Y ~ Gamma(k, 1).
138+
# Lurie, D., and H. O. Hartley. "Machine-generation of order statistics for Monte
139+
# Carlo computations." The American Statistician 26.1 (1972): 26-27.
140+
# this is slow if length(d.ranks) is close to n and quantile for d.dist is expensive,
141+
# but this branch is probably taken when length(d.ranks) is small or much smaller than n.
142+
xi = rand(rng, d.dist) # this is only used to obtain the type of samples from `d.dist`
143+
x = Vector{typeof(xi)}(undef, length(d.ranks))
144+
_rand!(rng, d, x)
145+
end
146+
return x
147+
end
148+
128149
function _rand!(rng::AbstractRNG, d::JointOrderStatistics, x::AbstractVector{<:Real})
129150
n = d.n
130151
if n == length(d.ranks) # ranks == 1:n

src/multivariate/multinomial.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ end
165165
# Sampling
166166

167167
# if only a single sample is requested, no alias table is created
168+
rand(rng::AbstractRNG, d::Multinomial) = multinom_rand(rng, ntrials(d), probs(d))
168169
_rand!(rng::AbstractRNG, d::Multinomial, x::AbstractVector{<:Real}) =
169170
multinom_rand!(rng, ntrials(d), probs(d), x)
170171

src/multivariate/mvlogitnormal.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ kldivergence(p::MvLogitNormal, q::MvLogitNormal) = kldivergence(p.normal, q.norm
8888

8989
# Sampling
9090

91+
function rand(rng::AbstractRNG, d::MvLogitNormal)
92+
x = rand(rng, d.normal)
93+
push!(x, zero(eltype(x)))
94+
StatsFuns.softmax!(x)
95+
return x
96+
end
97+
function rand(rng::AbstractRNG, d::MvLogitNormal, n::Int)
98+
r = rand(rng, d.normal, n)
99+
x = vcat(r, zeros(eltype(r), 1, n))
100+
StatsFuns.softmax!(x; dims=1)
101+
return x
102+
end
103+
91104
function _rand!(rng::AbstractRNG, d::MvLogitNormal, x::AbstractVecOrMat{<:Real})
92105
y = @views _drop1(x)
93106
rand!(rng, d.normal, y)

src/multivariate/mvlognormal.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,17 @@ var(d::MvLogNormal) = diag(cov(d))
232232
entropy(d::MvLogNormal) = length(d)*(1+log2π)/2 + logdetcov(d.normal)/2 + sum(mean(d.normal))
233233

234234
#See https://en.wikipedia.org/wiki/Log-normal_distribution
235+
function rand(rng::AbstractRNG, d::MvLogNormal)
236+
x = rand(rng, d.normal)
237+
map!(exp, x, x)
238+
return x
239+
end
240+
function rand(rng::AbstractRNG, d::MvLogNormal, n::Int)
241+
xs = rand(rng, d.normal, n)
242+
map!(exp, xs, xs)
243+
return xs
244+
end
245+
235246
function _rand!(rng::AbstractRNG, d::MvLogNormal, x::AbstractVecOrMat{<:Real})
236247
_rand!(rng, d.normal, x)
237248
map!(exp, x, x)

src/multivariate/mvnormal.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,17 @@ gradlogpdf(d::MvNormal, x::AbstractVector{<:Real}) = -(d.Σ \ (x .- d.μ))
273273

274274
# Sampling (for GenericMvNormal)
275275

276+
function rand(rng::AbstractRNG, d::MvNormal)
277+
x = unwhiten!(d.Σ, randn(rng, float(partype(d)), length(d)))
278+
x .+= d.μ
279+
return x
280+
end
281+
function rand(rng::AbstractRNG, d::MvNormal, n::Int)
282+
x = unwhiten!(d.Σ, randn(rng, float(partype(d)), length(d), n))
283+
x .+= d.μ
284+
return x
285+
end
286+
276287
function _rand!(rng::AbstractRNG, d::MvNormal, x::VecOrMat)
277288
unwhiten!(d.Σ, randn!(rng, x))
278289
x .+= d.μ

src/multivariate/mvnormalcanon.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,17 @@ if isdefined(PDMats, :PDSparseMat)
177177
unwhiten_winv!(J::PDSparseMat, x::AbstractVecOrMat) = x[:] = J.chol.PtL' \ x
178178
end
179179

180+
function rand(rng::AbstractRNG, d::MvNormalCanon)
181+
x = unwhiten_winv!(d.J, randn(rng, float(partype(d)), length(d)))
182+
x .+= d.μ
183+
return x
184+
end
185+
function rand(rng::AbstractRNG, d::MvNormalCanon, n::Int)
186+
x = unwhiten_winv!(d.J, randn(rng, float(partype(d)), length(d), n))
187+
x .+= d.μ
188+
return x
189+
end
190+
180191
function _rand!(rng::AbstractRNG, d::MvNormalCanon, x::AbstractVector)
181192
unwhiten_winv!(d.J, randn!(rng, x))
182193
x .+= d.μ

src/multivariate/mvtdist.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,21 @@ function gradlogpdf(d::GenericMvTDist, x::AbstractVector{<:Real})
155155
end
156156

157157
# Sampling (for GenericMvTDist)
158+
function rand(rng::AbstractRNG, d::GenericMvTDist)
159+
chisqd = Chisq{partype(d)}(d.df)
160+
y = sqrt(rand(rng, chisqd) / d.df)
161+
x = unwhiten!(d.Σ, randn(rng, typeof(y), length(d)))
162+
x .= x ./ y .+ d.μ
163+
x
164+
end
165+
function rand(rng::AbstractRNG, d::GenericMvTDist, n::Int)
166+
chisqd = Chisq{partype(d)}(d.df)
167+
y = rand(rng, chisqd, n)
168+
x = unwhiten!(d.Σ, randn(rng, eltype(y), length(d), n))
169+
x .= x ./ sqrt.(y' ./ d.df) .+ d.μ
170+
x
171+
end
172+
158173
function _rand!(rng::AbstractRNG, d::GenericMvTDist, x::AbstractVector{<:Real})
159174
chisqd = Chisq{partype(d)}(d.df)
160175
y = sqrt(rand(rng, chisqd) / d.df)

0 commit comments

Comments
 (0)