Skip to content

Commit 4891647

Browse files
committed
fix Dirichlet
1 parent 5a8698e commit 4891647

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

src/multivariate.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,95 @@ end
179179
"""
    Distributions.Categorical(p::TrackedVector; check_args = true)

Build a `TuringDiscreteNonParametric` whose support is `1:length(p)` and whose
probabilities are the tracked vector `p`. `check_args` is forwarded unchanged.
"""
function Distributions.Categorical(p::TrackedVector; check_args = true)
    support = 1:length(p)
    return TuringDiscreteNonParametric(support, p; check_args = check_args)
end
182+
183+
## Dirichlet ##
184+
185+
"""
    TuringDirichlet{T, TV<:AbstractVector}

Dirichlet-like continuous multivariate distribution.

# Fields
- `alpha`: concentration vector.
- `alpha0`: set by the constructors in this file to the total of `alpha`.
- `lmnB`: set by the constructors in this file to
  `sum(loggamma, alpha) - loggamma(alpha0)`; `logpdf` subtracts it as the
  normalizing term.
"""
struct TuringDirichlet{T,TV<:AbstractVector} <: ContinuousMultivariateDistribution
    alpha::TV
    alpha0::T
    lmnB::T
end
190+
"""
    check(alpha)

Return `true` when every element of `alpha` is strictly positive.

# Throws
- `ArgumentError`: if any element of `alpha` is not greater than zero.
"""
function check(alpha)
    if !all(ai -> ai > 0, alpha)
        throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
    end
    return true
end
194+
# Register `check` with Zygote's `@nograd`: it only validates its argument, so it is
# excluded from differentiation.
Zygote.@nograd DistributionsAD.check
195+
196+
"""
    TuringDirichlet(alpha::AbstractVector)

Construct a `TuringDirichlet` from the concentration vector `alpha`.

`alpha` is validated with `check` (every entry must be positive). The stored
`alpha0` is `sum(alpha)` and the stored `lmnB` is
`sum(loggamma, alpha) - loggamma(alpha0)`.
"""
function TuringDirichlet(alpha::AbstractVector)
    check(alpha)
    total = sum(alpha)
    lognorm = sum(loggamma, alpha) - loggamma(total)
    # Both scalar fields share one type parameter, so promote them to a common type.
    T = promote_type(typeof(total), typeof(lognorm))
    return TuringDirichlet{T,typeof(alpha)}(alpha, total, lognorm)
end
204+
205+
"""
    TuringDirichlet(d::Integer, alpha::Real)

Construct a symmetric `TuringDirichlet` of dimension `d` in which every
concentration equals `alpha`.

The stored `alpha0` is `alpha * d` and the stored `lmnB` is
`loggamma(alpha) * d - loggamma(alpha0)`.
"""
function TuringDirichlet(d::Integer, alpha::Real)
    total = alpha * d
    lognorm = loggamma(alpha) * d - loggamma(total)
    concentration = fill(alpha, d)
    # Both scalar fields share one type parameter, so promote them to a common type.
    T = promote_type(typeof(total), typeof(lognorm))
    return TuringDirichlet{T,typeof(concentration)}(concentration, total, lognorm)
end
213+
"""
    TuringDirichlet(alpha::AbstractVector{<:Integer})

Convert an integer concentration vector to its floating-point element type and
forward to the general `AbstractVector` constructor.
"""
function TuringDirichlet(alpha::AbstractVector{T}) where {T<:Integer}
    floated = convert(AbstractVector{float(T)}, alpha)
    return TuringDirichlet(floated)
end
217+
# Integer concentration: convert to `Float64` and reuse the `Real` method.
function TuringDirichlet(d::Integer, alpha::Integer)
    return TuringDirichlet(d, Float64(alpha))
end
218+
219+
# A tracked concentration vector yields a `TuringDirichlet` instead of a `Dirichlet`.
function Distributions.Dirichlet(alpha::TrackedVector)
    return TuringDirichlet(alpha)
end
220+
# A tracked scalar concentration yields a symmetric `TuringDirichlet` of dimension `d`.
function Distributions.Dirichlet(d::Integer, alpha::TrackedReal)
    return TuringDirichlet(d, alpha)
end
221+
222+
# Log-density of a single point `x`: delegates to `simplex_logpdf` with the stored
# concentration vector and normalizing term.
Distributions.logpdf(d::TuringDirichlet, x::AbstractVector) =
    simplex_logpdf(d.alpha, d.lmnB, x)
225+
# Log-density for a matrix `x`: delegates to the matrix method of `simplex_logpdf`,
# which produces one value per column.
Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix) =
    simplex_logpdf(d.alpha, d.lmnB, x)
228+
# Tracked input with a plain `Dirichlet`: rewrap the distribution's fields in a
# `TuringDirichlet` and dispatch to its `logpdf` methods.
function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T}
    turing_d = TuringDirichlet{T,typeof(d.alpha)}(d.alpha, d.alpha0, d.lmnB)
    return logpdf(turing_d, x)
end
232+
233+
# Under Zygote, `Dirichlet(alpha)` is evaluated and differentiated as
# `TuringDirichlet(alpha)`.
ZygoteRules.@adjoint function Distributions.Dirichlet(alpha)
    dist, back = pullback(TuringDirichlet, alpha)
    return dist, back
end
236+
# Under Zygote, `Dirichlet(d, alpha)` is evaluated and differentiated as
# `TuringDirichlet(d, alpha)`.
ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
    dist, back = pullback(TuringDirichlet, d, alpha)
    return dist, back
end
239+
240+
"""
    simplex_logpdf(alpha, lmnB, x::AbstractVector)

Return `sum((alpha .- 1) .* log.(x)) - lmnB`, the Dirichlet log-density of the
point `x` given concentrations `alpha` and normalizing term `lmnB`.
"""
function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    weighted_logs = @. (alpha - 1) * log(x)
    return sum(weighted_logs) - lmnB
end
243+
"""
    simplex_logpdf(alpha, lmnB, x::AbstractMatrix)

Return a vector with one entry per column `c` of `x`, each equal to
`sum((alpha .- 1) .* log.(c)) - lmnB`.

`x` must have at least one column, because the first column seeds the reduction.
"""
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    # One shared kernel guarantees that every column, including the first one used
    # as the reduction seed, has the normalizing term `lmnB` subtracted.
    column_logpdf(c) = sum((alpha .- 1) .* log.(c)) - lmnB
    @views init = vcat(column_logpdf(x[:, 1]))
    return mapreduce(column_logpdf, vcat, drop(eachcol(x), 1); init = init)
end
249+
250+
# Tracker gradient of the vector `simplex_logpdf`
# `f = sum((alpha .- 1) .* log.(x)) - lmnB`.
# The pullback returns the cotangents for `(alpha, lmnB, x)`:
#   ∂f/∂alpha = log.(x),  ∂f/∂lmnB = -1,  ∂f/∂x = (alpha .- 1) ./ x
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
        (Δ .* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1) ./ data(x))
    end
end
255+
# Tracker gradient of the matrix `simplex_logpdf`, whose j-th output is
# `sum((alpha .- 1) .* log.(x[:, j])) - lmnB`. `Δ` holds one cotangent per column.
# The pullback returns the cotangents for `(alpha, lmnB, x)`:
#   alpha: sum over j of Δ[j] .* log.(x[:, j])  == log.(x) * Δ
#   lmnB:  -sum(Δ)
#   x:     column j is Δ[j] .* (alpha .- 1) ./ x[:, j]
Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin
        (log.(data(x)) * Δ, -sum(Δ), ((data(alpha) .- 1) ./ data(x)) * Diagonal(Δ))
    end
end
260+
261+
# Zygote adjoint of the vector `simplex_logpdf`
# `f = sum((alpha .- 1) .* log.(x)) - lmnB`.
# The pullback returns the cotangents for `(alpha, lmnB, x)`:
#   ∂f/∂alpha = log.(x),  ∂f/∂lmnB = -1,  ∂f/∂x = (alpha .- 1) ./ x
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
    simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x)
end
264+
265+
# Zygote adjoint of the matrix `simplex_logpdf`, whose j-th output is
# `sum((alpha .- 1) .* log.(x[:, j])) - lmnB`. `Δ` holds one cotangent per column.
# The pullback returns the cotangents for `(alpha, lmnB, x)`:
#   alpha: sum over j of Δ[j] .* log.(x[:, j])  == log.(x) * Δ
#   lmnB:  -sum(Δ)
#   x:     column j is Δ[j] .* (alpha .- 1) ./ x[:, j]
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
    simplex_logpdf(alpha, lmnB, x), Δ -> begin
        (log.(x) * Δ, -sum(Δ), ((alpha .- 1) ./ x) * Diagonal(Δ))
    end
end
270+
182271
## MvNormal ##
183272

184273
"""
@@ -249,6 +338,7 @@ end
249338
# Column-wise log-density of a diagonal multivariate normal with mean `d.m` and
# per-dimension scale `d.σ`; returns one value per column of `x`.
function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix)
    # Term shared by all columns: k * log(2π) + 2 * sum(log.(σ)).
    logconst = size(x, 1) * log(2π) + 2 * sum(log.(d.σ))
    # Sum of squared standardized residuals for each column.
    sqdist = vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))
    return -(logconst .+ sqdist) ./ 2
end
341+
252342
# Log-density of a dense multivariate normal at the point `x`, using the stored
# factorization `d.C` for both the log-determinant and the whitened residual.
function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
    logconst = length(x) * log(2π) + logdet(d.C)
    # Squared norm of the residual after solving against the transposed upper factor.
    sqdist = sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))
    return -(logconst + sqdist) / 2
end

0 commit comments

Comments
 (0)