@@ -18,7 +18,8 @@ using ..DistributionsAD: DistributionsAD
1818
1919
2020import SpecialFunctions, NaNMath
21- import .. DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn
21+ import .. DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn,
22+ simplex_logpdf
2223import Base. Broadcast: materialize
2324import StatsFuns: logsumexp
2425
@@ -47,12 +48,25 @@ using ..DistributionsAD: TuringPoissonBinomial,
4748 TuringDirichlet,
4849 TuringScalMvNormal,
4950 TuringDiagMvNormal,
50- TuringDenseMvNormal
51+ TuringDenseMvNormal,
52+ VectorOfMultivariate,
53+ FillVectorOfMultivariate
5154
5255include (" reversediffx.jl" )
5356
5457adapt_randn (rng:: Random.AbstractRNG , x:: TrackedArray , dims... ) = adapt_randn (rng, value (x), dims... )
5558
59+ # without this definition tests of `VectorOfMultivariate` with `Dirichlet` fail
60+ # upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164
61+ function _logpdf (dist:: VectorOfMultivariate , x:: AbstractMatrix{<:TrackedReal} )
62+ return sum (i -> _logpdf (dist. dists[i], x[:, i]), axes (x, 2 ))
63+ end
64+
65+ # fix method ambiguity
66+ function _logpdf (dist:: FillVectorOfMultivariate , x:: AbstractMatrix{<:TrackedReal} )
67+ return loglikelihood (dist. dists. value, x)
68+ end
69+
5670function PoissonBinomial (p:: TrackedArray{<:Real} ; check_args= true )
5771 return TuringPoissonBinomial (p; check_args = check_args)
5872end
@@ -240,36 +254,60 @@ end
240254# zero mean,, constant variance
241255MvLogNormal (d:: Int , σ:: TrackedReal ) = TuringMvLogNormal (TuringMvNormal (d, σ))
242256
243- Dirichlet (alpha:: TrackedVector ) = TuringDirichlet (alpha)
257+ # Dirichlet
258+
259+ Dirichlet (alpha:: AbstractVector{<:TrackedReal} ) = TuringDirichlet (alpha)
244260Dirichlet (d:: Integer , alpha:: TrackedReal ) = TuringDirichlet (d, alpha)
245261
262+ function _logpdf (d:: Dirichlet , x:: AbstractVector{<:TrackedReal} )
263+ return _logpdf (TuringDirichlet (d. alpha, d. alpha0, d. lmnB), x)
264+ end
265+ function logpdf (d:: Dirichlet , x:: AbstractMatrix{<:TrackedReal} )
266+ return logpdf (TuringDirichlet (d. alpha, d. alpha0, d. lmnB), x)
267+ end
268+ function loglikelihood (d:: Dirichlet , x:: AbstractMatrix{<:TrackedReal} )
269+ return loglikelihood (TuringDirichlet (d. alpha, d. alpha0, d. lmnB), x)
270+ end
271+
272+ # default definition of `loglikelihood` yields gradients of zero?!
273+ # upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164
274+ function loglikelihood (d:: TuringDirichlet , x:: AbstractMatrix{<:TrackedReal} )
275+ return sum (i -> logpdf (d, x[:, i]), axes (x, 2 ))
276+ end
277+
246278for func_header in [
247- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: Real , x:: AbstractVector )),
279+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: Real , x:: AbstractVector )),
248280 :(simplex_logpdf (alpha:: AbstractVector , lmnB:: TrackedReal , x:: AbstractVector )),
249- :(simplex_logpdf (alpha:: AbstractVector , lmnB:: Real , x:: TrackedVector )),
250- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: TrackedReal , x:: AbstractVector )),
251- :(simplex_logpdf (alpha:: AbstractVector , lmnB:: TrackedReal , x:: TrackedVector )),
252- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: Real , x:: TrackedVector )),
253- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: TrackedReal , x:: TrackedVector )),
281+ :(simplex_logpdf (alpha:: AbstractVector , lmnB:: Real , x:: AbstractVector{<:TrackedReal} )),
282+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: TrackedReal , x:: AbstractVector )),
283+ :(simplex_logpdf (alpha:: AbstractVector , lmnB:: TrackedReal , x:: AbstractVector{<:TrackedReal} )),
284+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: Real , x:: AbstractVector{<:TrackedReal} )),
285+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: TrackedReal , x:: AbstractVector{<:TrackedReal} )),
254286
255- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: Real , x:: AbstractMatrix )),
287+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: Real , x:: AbstractMatrix )),
256288 :(simplex_logpdf (alpha:: AbstractVector , lmnB:: TrackedReal , x:: AbstractMatrix )),
257- :(simplex_logpdf (alpha:: AbstractVector , lmnB:: Real , x:: TrackedMatrix )),
258- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: TrackedReal , x:: AbstractMatrix )),
259- :(simplex_logpdf (alpha:: AbstractVector , lmnB:: TrackedReal , x:: TrackedMatrix )),
260- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: Real , x:: TrackedMatrix )),
261- :(simplex_logpdf (alpha:: TrackedVector , lmnB:: TrackedReal , x:: TrackedMatrix )),
289+ :(simplex_logpdf (alpha:: AbstractVector , lmnB:: Real , x:: AbstractMatrix{<:TrackedReal} )),
290+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: TrackedReal , x:: AbstractMatrix )),
291+ :(simplex_logpdf (alpha:: AbstractVector , lmnB:: TrackedReal , x:: AbstractMatrix{<:TrackedReal} )),
292+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: Real , x:: AbstractMatrix{<:TrackedReal} )),
293+ :(simplex_logpdf (alpha:: AbstractVector{<:TrackedReal} , lmnB:: TrackedReal , x:: AbstractMatrix{<:TrackedReal} )),
262294]
263295 @eval $ func_header = track (simplex_logpdf, alpha, lmnB, x)
264296end
265297@grad function simplex_logpdf (alpha, lmnB, x:: AbstractVector )
266- simplex_logpdf (value (alpha), value (lmnB), value (x)), Δ -> begin
267- (Δ .* log .(value (x)), - Δ, Δ .* (value (alpha) .- 1 ))
298+ _alpha = value (alpha)
299+ _lmnB = value (lmnB)
300+ _x = value (x)
301+ simplex_logpdf (_alpha, _lmnB, _x), Δ -> begin
302+ (Δ .* log .(_x), - Δ, Δ .* (_alpha .- 1 ) ./ _x)
268303 end
269304end
270305@grad function simplex_logpdf (alpha, lmnB, x:: AbstractMatrix )
271- simplex_logpdf (value (alpha), value (lmnB), value (x)), Δ -> begin
272- (log .(value (x)) * Δ, - sum (Δ), repeat (value (alpha) .- 1 , 1 , size (x, 2 )) * Diagonal (Δ))
306+ _alpha = value (alpha)
307+ _lmnB = value (lmnB)
308+ _x = value (x)
309+ simplex_logpdf (_alpha, _lmnB, _x), Δ -> begin
310+ (log .(_x) * Δ, - sum (Δ), ((_alpha .- 1 ) ./ _x) * Diagonal (Δ))
273311 end
274312end
275313
0 commit comments