Skip to content

Commit 735f65a

Browse files
committed
remove Dirichlet fix
1 parent 8a6db2d commit 735f65a

File tree

2 files changed

+1
-90
lines changed

2 files changed

+1
-90
lines changed

src/multivariate.jl

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,91 +1,3 @@
1-
## Dirichlet ##
2-
3-
struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
4-
alpha::TV
5-
alpha0::T
6-
lmnB::T
7-
end
8-
function check(alpha)
9-
all(ai -> ai > 0, alpha) ||
10-
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
11-
end
12-
Zygote.@nograd DistributionsAD.check
13-
14-
function TuringDirichlet(alpha::AbstractVector)
15-
check(alpha)
16-
alpha0 = sum(alpha)
17-
lmnB = sum(loggamma, alpha) - loggamma(alpha0)
18-
T = promote_type(typeof(alpha0), typeof(lmnB))
19-
TV = typeof(alpha)
20-
TuringDirichlet{T, TV}(alpha, alpha0, lmnB)
21-
end
22-
23-
function TuringDirichlet(d::Integer, alpha::Real)
24-
alpha0 = alpha * d
25-
_alpha = fill(alpha, d)
26-
lmnB = loggamma(alpha) * d - loggamma(alpha0)
27-
T = promote_type(typeof(alpha0), typeof(lmnB))
28-
TV = typeof(_alpha)
29-
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
30-
end
31-
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
32-
Tf = float(T)
33-
TuringDirichlet(convert(AbstractVector{Tf}, alpha))
34-
end
35-
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha))
36-
37-
Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
38-
Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)
39-
40-
function Distributions.logpdf(d::TuringDirichlet, x::AbstractVector)
41-
simplex_logpdf(d.alpha, d.lmnB, x)
42-
end
43-
function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix)
44-
simplex_logpdf(d.alpha, d.lmnB, x)
45-
end
46-
function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T}
47-
TV = typeof(d.alpha)
48-
logpdf(TuringDirichlet{T, TV}(d.alpha, d.alpha0, d.lmnB), x)
49-
end
50-
51-
@adjoint function Distributions.Dirichlet(alpha)
52-
return pullback(TuringDirichlet, alpha)
53-
end
54-
@adjoint function Distributions.Dirichlet(d, alpha)
55-
return pullback(TuringDirichlet, d, alpha)
56-
end
57-
58-
function simplex_logpdf(alpha, lmnB, x::AbstractVector)
59-
sum((alpha .- 1) .* log.(x)) - lmnB
60-
end
61-
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
62-
init = vcat(sum((alpha .- 1) .* log.(view(x, :, 1))))
63-
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
64-
sum((alpha .- 1) .* log.(c)) - lmnB
65-
end
66-
end
67-
68-
@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
69-
simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
70-
.* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1))
71-
end
72-
end
73-
@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
74-
simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
75-
(log.(data(x)) * Δ, -sum(Δ), repeat(data(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ))
76-
end
77-
end
78-
79-
@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
80-
return simplex_logpdf(alpha, lmnB, x), Δ ->.* log.(x), -Δ, Δ .* (alpha .- 1))
81-
end
82-
83-
@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
84-
return simplex_logpdf(alpha, lmnB, x), Δ -> begin
85-
(log.(x) * Δ, -sum(Δ), repeat(alpha .- 1, 1, size(x, 2)) * Diagonal(Δ))
86-
end
87-
end
88-
891
## MvNormal ##
902

913
"""

test/distributions.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,6 @@ separator()
215215
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
216216
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
217217
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
218-
DistSpec(:Dirichlet, (alpha,), dir_val),
219-
DistSpec(:Dirichlet, (alpha,), dir_val),
220218
]
221219

222220
broken_mult_cont_dists = [
@@ -233,6 +231,7 @@ separator()
233231
DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat),
234232
DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat),
235233
DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat),
234+
DistSpec(:Dirichlet, (alpha,), dir_val),
236235
# Test failure
237236
DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat),
238237
DistSpec(:MvNormal, (cov_mat,), norm_val_mat),

0 commit comments

Comments
 (0)