Skip to content

Commit 6b83376

Browse files
authored
Merge pull request #29 from TuringLang/mt/categorical_dirichlet
Categorical and Dirichlet
2 parents b296d39 + cc8d716 commit 6b83376

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

src/multivariate.jl

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,91 @@
1+
## Dirichlet ##
2+
3+
"""
    TuringDirichlet{T,TV<:AbstractVector}

Dirichlet distribution container, declared as a `ContinuousMultivariateDistribution`.

# Fields
- `alpha::TV`: concentration parameter vector.
- `alpha0::T`: the outer constructors store `sum(alpha)` here.
- `lmnB::T`: log-normalising constant; the outer constructors store
  `sum(loggamma, alpha) - loggamma(alpha0)` here.
"""
struct TuringDirichlet{T,TV<:AbstractVector} <: ContinuousMultivariateDistribution
    alpha::TV
    alpha0::T
    lmnB::T
end
8+
"""
    check(alpha)

Return `true` when every element of `alpha` is strictly positive; otherwise throw an
`ArgumentError`.
"""
function check(alpha)
    if !all(a -> a > 0, alpha)
        throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
    end
    return true
end
# `check` only validates its argument, so it is registered as having no gradient.
Zygote.@nograd DistributionsAD.check
13+
14+
"""
    TuringDirichlet(alpha::AbstractVector)

Build a `TuringDirichlet` from the concentration vector `alpha`.

`alpha` is validated with `check`, which throws an `ArgumentError` unless all entries
are positive. The stored `alpha0` is `sum(alpha)` and the stored `lmnB` is
`sum(loggamma, alpha) - loggamma(alpha0)`.
"""
function TuringDirichlet(alpha::AbstractVector)
    check(alpha)
    total = sum(alpha)
    lognorm = sum(loggamma, alpha) - loggamma(total)
    # Both scalar fields share one type parameter, so promote them to a common type.
    T = promote_type(typeof(total), typeof(lognorm))
    return TuringDirichlet{T,typeof(alpha)}(alpha, total, lognorm)
end
22+
23+
"""
    TuringDirichlet(d::Integer, alpha::Real)

Build a symmetric `TuringDirichlet` of dimension `d` whose concentration vector is
`fill(alpha, d)`.

`alpha` is validated with `check`, matching the vector constructor, so a non-positive
concentration throws an `ArgumentError` instead of silently producing an invalid
distribution. The stored `alpha0` is `alpha * d` and the stored `lmnB` is
`loggamma(alpha) * d - loggamma(alpha0)`.
"""
function TuringDirichlet(d::Integer, alpha::Real)
    # Same positivity requirement as `TuringDirichlet(alpha::AbstractVector)`.
    check(alpha)
    alpha0 = alpha * d
    _alpha = fill(alpha, d)
    lmnB = loggamma(alpha) * d - loggamma(alpha0)
    # Both scalar fields share one type parameter, so promote them to a common type.
    T = promote_type(typeof(alpha0), typeof(lmnB))
    TV = typeof(_alpha)
    return TuringDirichlet{T,TV}(_alpha, alpha0, lmnB)
end
31+
"""
    TuringDirichlet(alpha::AbstractVector{<:Integer})

Convert an integer concentration vector to its floating-point element type and forward
to the generic vector constructor.
"""
function TuringDirichlet(alpha::AbstractVector{T}) where {T<:Integer}
    return TuringDirichlet(convert(AbstractVector{float(T)}, alpha))
end
35+
# Integer concentration: convert to `Float64` and forward to the `Real` method.
function TuringDirichlet(d::Integer, alpha::Integer)
    return TuringDirichlet(d, Float64(alpha))
end
36+
37+
# Constructing a `Dirichlet` from Tracker-tracked parameters yields a `TuringDirichlet`.
function Distributions.Dirichlet(alpha::TrackedVector)
    return TuringDirichlet(alpha)
end

function Distributions.Dirichlet(d::Integer, alpha::TrackedReal)
    return TuringDirichlet(d, alpha)
end
39+
40+
# Log-density of a single point `x`, delegated to `simplex_logpdf`.
Distributions.logpdf(d::TuringDirichlet, x::AbstractVector) =
    simplex_logpdf(d.alpha, d.lmnB, x)
43+
# Log-densities for a matrix `x`, delegated to the matrix method of `simplex_logpdf`.
Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix) =
    simplex_logpdf(d.alpha, d.lmnB, x)
46+
# For tracked inputs, re-wrap the `Dirichlet`'s fields in a `TuringDirichlet` so the
# evaluation goes through the `TuringDirichlet` `logpdf` methods.
function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T}
    wrapped = TuringDirichlet{T,typeof(d.alpha)}(d.alpha, d.alpha0, d.lmnB)
    return logpdf(wrapped, x)
end
50+
51+
# Differentiating `Dirichlet(alpha)` uses the pullback of `TuringDirichlet(alpha)`.
ZygoteRules.@adjoint function Distributions.Dirichlet(alpha)
    value, back = pullback(TuringDirichlet, alpha)
    return value, back
end
54+
# Differentiating `Dirichlet(d, alpha)` uses the pullback of `TuringDirichlet(d, alpha)`.
ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
    value, back = pullback(TuringDirichlet, d, alpha)
    return value, back
end
57+
58+
"""
    simplex_logpdf(alpha, lmnB, x::AbstractVector)
    simplex_logpdf(alpha, lmnB, x::AbstractMatrix)

Evaluate `sum((alpha .- 1) .* log.(x)) - lmnB`.

For a vector `x` the result is a scalar. For a matrix `x` each column is treated as one
point and the result is a vector with one entry per column; a matrix with zero columns
yields an empty vector.
"""
function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    return sum((alpha .- 1) .* log.(x)) - lmnB
end

function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    # `alpha` broadcasts down every column, so one reduction over rows gives all the
    # per-column values at once, with no special case for the first column.
    return vec(sum((alpha .- 1) .* log.(x); dims=1)) .- lmnB
end
67+
68+
# Tracker gradient of `sum((alpha .- 1) .* log.(x)) - lmnB` for a vector `x`.
# For an upstream sensitivity `Δ` the partial derivatives are
#   alpha: Δ .* log.(x)
#   lmnB:  -Δ
#   x:     Δ .* (alpha .- 1) ./ x    (since d/dx log(x) = 1/x)
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    a, xs = data(alpha), data(x)
    simplex_logpdf(a, data(lmnB), xs), Δ -> begin
        (Δ .* log.(xs), -Δ, Δ .* (a .- 1) ./ xs)
    end
end
73+
# Tracker gradient of the matrix method, which returns one value per column of `x`.
# `Δ` has one entry per column, so the partial derivatives are
#   alpha: log.(x) * Δ                            (sum over columns, weighted by Δ)
#   lmnB:  -sum(Δ)
#   x:     ((alpha .- 1) ./ x) * Diagonal(Δ)      (column j scaled by Δ[j]; d/dx log(x) = 1/x)
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    a, xs = data(alpha), data(x)
    simplex_logpdf(a, data(lmnB), xs), Δ -> begin
        (log.(xs) * Δ, -sum(Δ), ((a .- 1) ./ xs) * Diagonal(Δ))
    end
end
78+
79+
# Zygote adjoint of `sum((alpha .- 1) .* log.(x)) - lmnB` for a vector `x`.
# For an upstream sensitivity `Δ` the partial derivatives are
#   alpha: Δ .* log.(x)
#   lmnB:  -Δ
#   x:     Δ .* (alpha .- 1) ./ x    (since d/dx log(x) = 1/x)
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x)
end
82+
83+
# Zygote adjoint of the matrix method, which returns one value per column of `x`.
# `Δ` has one entry per column, so the partial derivatives are
#   alpha: log.(x) * Δ                            (sum over columns, weighted by Δ)
#   lmnB:  -sum(Δ)
#   x:     ((alpha .- 1) ./ x) * Diagonal(Δ)      (column j scaled by Δ[j]; d/dx log(x) = 1/x)
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    simplex_logpdf(alpha, lmnB, x), Δ -> begin
        (log.(x) * Δ, -sum(Δ), ((alpha .- 1) ./ x) * Diagonal(Δ))
    end
end
88+
189
## MvNormal ##
290

391
"""
@@ -68,6 +156,7 @@ end
68156
# Log-density for each column of `x`, using the fields `d.m` and `d.σ`.
function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix)
    # Term shared by every column.
    logconst = size(x, 1) * log(2π) + 2 * sum(log.(d.σ))
    # Per-column sum of squared standardised residuals.
    sqdist = vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))
    return -(logconst .+ sqdist) ./ 2
end
159+
71160
# Log-density of a single point `x`, using the fields `d.m` and `d.C`.
function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
    # Residual solved against the adjoint of the factor `d.C.U`.
    whitened = zygote_ldiv(d.C.U', x .- d.m)
    return -(length(x) * log(2π) + logdet(d.C) + sum(abs2.(whitened))) / 2
end

src/univariate.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,4 +337,14 @@ function _dft_zygote(x::Vector{T}) where T
337337
end
338338
return copy(y)
339339
end
340-
=#
340+
=#
341+
342+
## Categorical ##
343+
344+
# Conversion of a `DiscreteNonParametric` to its own concrete type: rebuild it from its
# support and probabilities with argument checking disabled.
function Base.convert(
    ::Type{Distributions.DiscreteNonParametric{T,P,Ts,Ps}},
    d::Distributions.DiscreteNonParametric{T,P,Ts,Ps},
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:AbstractVector{P}}
    xs = support(d)
    ps = probs(d)
    return DiscreteNonParametric{T,P,Ts,Ps}(xs, ps; check_args=false)
end
350+

test/distributions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ separator()
215215
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
216216
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
217217
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
218+
DistSpec(:Dirichlet, (alpha,), dir_val),
218219
]
219220

220221
broken_mult_cont_dists = [
@@ -231,7 +232,6 @@ separator()
231232
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat),
232233
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat),
233234
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat),
234-
DistSpec(:Dirichlet, (alpha,), dir_val),
235235
# Test failure
236236
DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat),
237237
DistSpec(:MvNormal, (cov_mat,), norm_val_mat),

0 commit comments

Comments
 (0)