|
1 | 1 | ## Dirichlet ## |
2 | 2 |
|
3 | | -struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution |
| 3 | +struct TuringDirichlet{T<:Real,TV<:AbstractVector,S<:Real} <: ContinuousMultivariateDistribution |
4 | 4 | alpha::TV |
5 | 5 | alpha0::T |
6 | | - lmnB::T |
7 | | -end |
8 | | -Base.length(d::TuringDirichlet) = length(d.alpha) |
9 | | -function check(alpha) |
10 | | - all(ai -> ai > 0, alpha) || |
11 | | - throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) |
12 | | -end |
13 | | - |
14 | | -function Distributions._rand!(rng::Random.AbstractRNG, |
15 | | - d::TuringDirichlet, |
16 | | - x::AbstractVector{<:Real}) |
17 | | - s = 0.0 |
18 | | - n = length(x) |
19 | | - α = d.alpha |
20 | | - for i in 1:n |
21 | | - @inbounds s += (x[i] = rand(rng, Gamma(α[i]))) |
22 | | - end |
23 | | - Distributions.multiply!(x, inv(s)) # this returns x |
| 6 | + lmnB::S |
24 | 7 | end |
25 | 8 |
|
26 | 9 | function TuringDirichlet(alpha::AbstractVector) |
27 | | - check(alpha) |
| 10 | + all(ai -> ai > 0, alpha) || |
| 11 | + throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) |
| 12 | + |
28 | 13 | alpha0 = sum(alpha) |
29 | 14 | lmnB = sum(loggamma, alpha) - loggamma(alpha0) |
30 | | - T = promote_type(typeof(alpha0), typeof(lmnB)) |
31 | | - TV = typeof(alpha) |
32 | | - TuringDirichlet{T, TV}(alpha, alpha0, lmnB) |
33 | | -end |
34 | 15 |
|
35 | | -function TuringDirichlet(d::Integer, alpha::Real) |
36 | | - alpha0 = alpha * d |
37 | | - _alpha = fill(alpha, d) |
38 | | - lmnB = loggamma(alpha) * d - loggamma(alpha0) |
39 | | - T = promote_type(typeof(alpha0), typeof(lmnB)) |
40 | | - TV = typeof(_alpha) |
41 | | - TuringDirichlet{T, TV}(_alpha, alpha0, lmnB) |
42 | | -end |
43 | | -function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} |
44 | | - TuringDirichlet(float.(alpha)) |
| 16 | + return TuringDirichlet(alpha, alpha0, lmnB) |
45 | 17 | end |
46 | | -TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha)) |
| 18 | +TuringDirichlet(d::Integer, alpha::Real) = TuringDirichlet(Fill(alpha, d)) |
47 | 19 |
|
| 20 | +# TODO: remove? |
| 21 | +TuringDirichlet(alpha::AbstractVector{<:Integer}) = TuringDirichlet(float.(alpha)) |
| 22 | +TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha)) |
| 23 | + |
| 24 | +# TODO: remove and use `Dirichlet` only for `Tracker.TrackedVector` |
48 | 25 | Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha) |
49 | 26 |
|
| 27 | +TuringDirichlet(d::Dirichlet) = TuringDirichlet(d.alpha, d.alpha0, d.lmnB) |
| 28 | + |
| 29 | +Base.length(d::TuringDirichlet) = length(d.alpha) |
| 30 | + |
| 31 | +# copied from Distributions |
| 32 | +# TODO: remove and use `Dirichlet`? |
| 33 | +function Distributions._rand!( |
| 34 | + rng::Random.AbstractRNG, |
| 35 | + d::TuringDirichlet, |
| 36 | + x::AbstractVector{<:Real}, |
| 37 | +) |
| 38 | + @inbounds for (i, αi) in zip(eachindex(x), d.alpha) |
| 39 | + x[i] = rand(rng, Gamma(αi)) |
| 40 | + end |
| 41 | + Distributions.multiply!(x, inv(sum(x))) # this returns x |
| 42 | +end |
| 43 | +function Distributions._rand!( |
| 44 | + rng::AbstractRNG, |
| 45 | + d::TuringDirichlet{<:Real,<:FillArrays.AbstractFill}, |
| 46 | + x::AbstractVector{<:Real} |
| 47 | +) |
| 48 | + rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) |
| 49 | + Distributions.multiply!(x, inv(sum(x))) # this returns x |
| 50 | +end |
| 51 | + |
50 | 52 | function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real}) |
51 | 53 | return simplex_logpdf(d.alpha, d.lmnB, x) |
52 | 54 | end |
|
0 commit comments