Skip to content

Commit 989520a

Browse files
committed
Add missing init stuff to utils
1 parent b006f9e commit 989520a

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

src/utils.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,39 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector, n::In
6363
r .= val
6464
return r
6565
end
66+
67+
68+
# ROBUST INITIALISATIONS
69+
# Uniform rand with range 2; ref: https://mc-stan.org/docs/2_19/reference-manual/initialization.html
70+
randrealuni() = Real(2rand())
71+
randrealuni(args...) = map(Real, 2rand(args...))
72+
73+
const Transformable = Union{TransformDistribution, SimplexDistribution, PDMatDistribution}
74+
75+
76+
#################################
77+
# Single-sample initialisations #
78+
#################################
79+
80+
init(dist::Transformable) = inittrans(dist)
81+
init(dist::Distribution) = rand(dist)
82+
83+
inittrans(dist::UnivariateDistribution) = invlink(dist, randrealuni())
84+
inittrans(dist::MultivariateDistribution) = invlink(dist, randrealuni(size(dist)[1]))
85+
inittrans(dist::MatrixDistribution) = invlink(dist, randrealuni(size(dist)...))
86+
87+
88+
################################
89+
# Multi-sample initialisations #
90+
################################
91+
92+
init(dist::Transformable, n::Int) = inittrans(dist, n)
93+
init(dist::Distribution, n::Int) = rand(dist, n)
94+
95+
inittrans(dist::UnivariateDistribution, n::Int) = invlink(dist, randrealuni(n))
96+
function inittrans(dist::MultivariateDistribution, n::Int)
97+
return invlink(dist, randrealuni(size(dist)[1], n))
98+
end
99+
function inittrans(dist::MatrixDistribution, n::Int)
100+
return invlink(dist, [randrealuni(size(dist)...) for _ in 1:n])
101+
end

0 commit comments

Comments
 (0)