Skip to content

Commit 7f7651b

Browse files
committed
fix Dirichlet
1 parent f8cf678 commit 7f7651b

File tree

3 files changed

+91
-2
lines changed

3 files changed

+91
-2
lines changed

src/DistributionsAD.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ using Distributions: AbstractMvLogNormal,
1818
ContinuousMultivariateDistribution
1919
using DiffRules, SpecialFunctions
2020
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
21+
using Base.Iterators: drop
2122

2223
import StatsFuns: logsumexp,
2324
binomlogpdf,

src/multivariate.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,91 @@
1+
## Dirichlet ##
2+
3+
"""
    TuringDirichlet{T, TV<:AbstractVector}

Dirichlet-like continuous multivariate distribution that keeps its parameters in
whatever container type `TV` they were supplied in.

The outer constructors in this file fill `alpha0` with `sum(alpha)` and `lmnB` with
`sum(loggamma, alpha) - loggamma(alpha0)`.
"""
struct TuringDirichlet{T, TV<:AbstractVector} <: ContinuousMultivariateDistribution
    # Concentration parameters.
    alpha::TV
    # Cached sum of the concentration parameters.
    alpha0::T
    # Cached log-normalizer; `simplex_logpdf` subtracts it from the weighted log terms.
    lmnB::T
end
8+
"""
    check(alpha)

Validate Dirichlet concentration parameters.

Return `true` when every element of `alpha` is strictly greater than zero and throw an
`ArgumentError` otherwise.
"""
function check(alpha)
    # `!(ai > 0)` rather than `ai <= 0` so that values that fail the comparison
    # (e.g. NaN) are rejected exactly as `all(ai -> ai > 0, alpha)` would reject them.
    if any(ai -> !(ai > 0), alpha)
        throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
    end
    return true
end
12+
# Parameter validation contributes nothing to the gradient, so Zygote is told not to
# differentiate through `check`.
Zygote.@nograd DistributionsAD.check
13+
14+
"""
    TuringDirichlet(alpha::AbstractVector)

Build a `TuringDirichlet` from a vector of concentration parameters.

`alpha` is validated with `check`, then the parameter sum and the log-normalizer
`sum(loggamma, alpha) - loggamma(sum(alpha))` are computed once and stored.
"""
function TuringDirichlet(alpha::AbstractVector)
    check(alpha)
    total = sum(alpha)
    lognorm = sum(loggamma, alpha) - loggamma(total)
    # Both scalar fields share one type parameter, so promote to a common type.
    S = promote_type(typeof(total), typeof(lognorm))
    return TuringDirichlet{S, typeof(alpha)}(alpha, total, lognorm)
end
22+
23+
"""
    TuringDirichlet(d::Integer, alpha::Real)

Build a symmetric `TuringDirichlet` of dimension `d` in which every concentration
parameter equals `alpha`.

Throws an `ArgumentError` (via `check`) when `alpha` is not strictly positive, matching
the validation done by the vector constructor.
"""
function TuringDirichlet(d::Integer, alpha::Real)
    # A scalar iterates as a single element, so `check` validates it like a 1-vector.
    check(alpha)
    alpha0 = alpha * d
    _alpha = fill(alpha, d)
    # All entries are equal, so sum(loggamma, _alpha) collapses to loggamma(alpha) * d.
    lmnB = loggamma(alpha) * d - loggamma(alpha0)
    T = promote_type(typeof(alpha0), typeof(lmnB))
    TV = typeof(_alpha)
    return TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
end
31+
"""
    TuringDirichlet(alpha::AbstractVector{<:Integer})

Convert integer concentration parameters to the matching floating-point element type
and forward to the general vector constructor.
"""
function TuringDirichlet(alpha::AbstractVector{T}) where {T<:Integer}
    floated = convert(AbstractVector{float(T)}, alpha)
    return TuringDirichlet(floated)
end
35+
# Integer concentration: promote to Float64 and reuse the `Real` method.
function TuringDirichlet(d::Integer, alpha::Integer)
    return TuringDirichlet(d, Float64(alpha))
end
36+
37+
# Constructing a `Dirichlet` from tracked parameters yields a `TuringDirichlet` instead.
function Distributions.Dirichlet(alpha::TrackedVector)
    return TuringDirichlet(alpha)
end
function Distributions.Dirichlet(d::Integer, alpha::TrackedReal)
    return TuringDirichlet(d, alpha)
end
39+
40+
# Log-density of a single point `x`; the work is delegated to `simplex_logpdf`.
Distributions.logpdf(d::TuringDirichlet, x::AbstractVector) =
    simplex_logpdf(d.alpha, d.lmnB, x)
43+
# Log-densities for a matrix `x`; delegated to the matrix method of `simplex_logpdf`,
# which processes `x` column by column.
Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix) =
    simplex_logpdf(d.alpha, d.lmnB, x)
46+
# A plain `Dirichlet` evaluated at tracked data: rewrap its fields in a
# `TuringDirichlet` and evaluate through that type's `logpdf` methods.
function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T}
    wrapped = TuringDirichlet{T, typeof(d.alpha)}(d.alpha, d.alpha0, d.lmnB)
    return logpdf(wrapped, x)
end
50+
51+
# Under Zygote, building a `Dirichlet` is differentiated as if `TuringDirichlet` had
# been called with the same argument.
ZygoteRules.@adjoint Distributions.Dirichlet(alpha) = pullback(TuringDirichlet, alpha)
54+
# Two-argument form: differentiate `Dirichlet(d, alpha)` via `TuringDirichlet(d, alpha)`.
ZygoteRules.@adjoint Distributions.Dirichlet(d, alpha) =
    pullback(TuringDirichlet, d, alpha)
57+
58+
"""
    simplex_logpdf(alpha, lmnB, x::AbstractVector)

Return `sum((alpha .- 1) .* log.(x)) - lmnB`, the log-density at the point `x` for
concentration parameters `alpha` and precomputed log-normalizer `lmnB`.
"""
function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    weighted_logs = @. (alpha - 1) * log(x)
    return sum(weighted_logs) - lmnB
end
61+
"""
    simplex_logpdf(alpha, lmnB, x::AbstractMatrix)

Return a vector holding, for every column `c` of `x`, the value
`sum((alpha .- 1) .* log.(c)) - lmnB`.

`x` must have at least one column. The result is assembled with `vcat`, seeded with the
first column's value and extended by each remaining column in order.
"""
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    # One shared kernel for every column, so the first column gets the same `- lmnB`
    # normalization as the rest (the seed value previously omitted it).
    col_logpdf(c) = sum((alpha .- 1) .* log.(c)) - lmnB
    init = vcat(col_logpdf(view(x, :, 1)))
    return mapreduce(col_logpdf, vcat, drop(eachcol(x), 1); init = init)
end
67+
68+
# Tracker gradient rule for the vector method of `simplex_logpdf`.
# With f = sum((alpha .- 1) .* log.(x)) - lmnB, the partial derivatives are
#   ∂f/∂alpha = log.(x),   ∂f/∂lmnB = -1,   ∂f/∂x = (alpha .- 1) ./ x,
# each scaled by the incoming scalar sensitivity Δ.
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
        (Δ .* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1) ./ data(x))
    end
end
73+
# Tracker gradient rule for the matrix method of `simplex_logpdf`.
# Output j is f_j = sum((alpha .- 1) .* log.(x[:, j])) - lmnB, and Δ holds one
# sensitivity per column, so:
#   alpha: sum_j Δ[j] * log.(x[:, j])           == log.(x) * Δ
#   lmnB : -sum(Δ)
#   x    : column j is Δ[j] * (alpha .- 1) ./ x[:, j]
#          == ((alpha .- 1) ./ x) * Diagonal(Δ)
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
        (log.(data(x)) * Δ, -sum(Δ), ((data(alpha) .- 1) ./ data(x)) * Diagonal(Δ))
    end
end
78+
79+
# Zygote adjoint for the vector method of `simplex_logpdf`.
# With f = sum((alpha .- 1) .* log.(x)) - lmnB, the pullback returns
#   (Δ * ∂f/∂alpha, Δ * ∂f/∂lmnB, Δ * ∂f/∂x)
#   = (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x).
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    simplex_logpdf(alpha, lmnB, x), Δ -> begin
        (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x)
    end
end
82+
83+
# Zygote adjoint for the matrix method of `simplex_logpdf`.
# Output j is f_j = sum((alpha .- 1) .* log.(x[:, j])) - lmnB, and Δ holds one
# sensitivity per column, so:
#   alpha: log.(x) * Δ
#   lmnB : -sum(Δ)
#   x    : ((alpha .- 1) ./ x) * Diagonal(Δ), i.e. column j scaled by Δ[j]
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    simplex_logpdf(alpha, lmnB, x), Δ -> begin
        (log.(x) * Δ, -sum(Δ), ((alpha .- 1) ./ x) * Diagonal(Δ))
    end
end
88+
189
## MvNormal ##
290

391
"""

test/distributions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ separator()
205205
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
206206
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
207207
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
208+
DistSpec(:Dirichlet, (alpha,), dir_val),
209+
DistSpec(:Dirichlet, (alpha,), dir_val),
208210
]
209211

210212
broken_mult_cont_dists = [
@@ -215,14 +217,12 @@ separator()
215217
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_vec),
216218
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_vec),
217219
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_vec),
218-
DistSpec(:Dirichlet, (alpha,), dir_val),
219220
DistSpec(:MvNormalCanon, (mean, cov_mat), norm_val_mat),
220221
DistSpec(:MvNormalCanon, (mean, cov_vec), norm_val_mat),
221222
DistSpec(:MvNormalCanon, (mean, cov_num), norm_val_mat),
222223
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat),
223224
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat),
224225
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat),
225-
DistSpec(:Dirichlet, (alpha,), dir_val),
226226
# Test failure
227227
DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat),
228228
DistSpec(:MvNormal, (cov_mat,), norm_val_mat),

0 commit comments

Comments
 (0)