Skip to content

Commit 3053aae

Browse files
authored
Fix Dirichlet with ReverseDiff (#151)
1 parent 2a6622c commit 3053aae

File tree

6 files changed

+84
-34
lines changed

6 files changed

+84
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.16"
3+
version = "0.6.17"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/multivariate.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,9 @@ ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
6363
return ZygoteRules.pullback(TuringDirichlet, d, alpha)
6464
end
6565

66-
function simplex_logpdf(alpha, lmnB, x::AbstractVector)
67-
sum((alpha .- 1) .* log.(x)) - lmnB
68-
end
66+
simplex_logpdf(alpha, lmnB, x::AbstractVector) = sum(xlogy.(alpha .- 1, x)) - lmnB
6967
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
70-
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB)
71-
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
72-
sum((alpha .- 1) .* log.(c)) - lmnB
73-
end
68+
return vec(sum(xlogy.(alpha .- 1, x); dims=1)) .- lmnB
7469
end
7570

7671
ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)

src/reversediff.jl

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ using ..DistributionsAD: DistributionsAD
1818

1919

2020
import SpecialFunctions, NaNMath
21-
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn
21+
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn,
22+
simplex_logpdf
2223
import Base.Broadcast: materialize
2324
import StatsFuns: logsumexp
2425

@@ -47,12 +48,25 @@ using ..DistributionsAD: TuringPoissonBinomial,
4748
TuringDirichlet,
4849
TuringScalMvNormal,
4950
TuringDiagMvNormal,
50-
TuringDenseMvNormal
51+
TuringDenseMvNormal,
52+
VectorOfMultivariate,
53+
FillVectorOfMultivariate
5154

5255
include("reversediffx.jl")
5356

5457
adapt_randn(rng::Random.AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, value(x), dims...)
5558

59+
# without this definition tests of `VectorOfMultivariate` with `Dirichlet` fail
60+
# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164
61+
function _logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:TrackedReal})
62+
return sum(i -> _logpdf(dist.dists[i], x[:, i]), axes(x, 2))
63+
end
64+
65+
# fix method ambiguity
66+
function _logpdf(dist::FillVectorOfMultivariate, x::AbstractMatrix{<:TrackedReal})
67+
return loglikelihood(dist.dists.value, x)
68+
end
69+
5670
function PoissonBinomial(p::TrackedArray{<:Real}; check_args=true)
5771
return TuringPoissonBinomial(p; check_args = check_args)
5872
end
@@ -240,36 +254,60 @@ end
240254
# zero mean, constant variance
241255
MvLogNormal(d::Int, σ::TrackedReal) = TuringMvLogNormal(TuringMvNormal(d, σ))
242256

243-
Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
257+
# Dirichlet
258+
259+
Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha)
244260
Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)
245261

262+
function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal})
263+
return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
264+
end
265+
function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
266+
return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
267+
end
268+
function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
269+
return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
270+
end
271+
272+
# default definition of `loglikelihood` yields gradients of zero?!
273+
# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164
274+
function loglikelihood(d::TuringDirichlet, x::AbstractMatrix{<:TrackedReal})
275+
return sum(i -> logpdf(d, x[:, i]), axes(x, 2))
276+
end
277+
246278
for func_header in [
247-
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractVector)),
279+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector)),
248280
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector)),
249-
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedVector)),
250-
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractVector)),
251-
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedVector)),
252-
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedVector)),
253-
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedVector)),
281+
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::AbstractVector{<:TrackedReal})),
282+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractVector)),
283+
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector{<:TrackedReal})),
284+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector{<:TrackedReal})),
285+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractVector{<:TrackedReal})),
254286

255-
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractMatrix)),
287+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractMatrix)),
256288
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix)),
257-
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedMatrix)),
258-
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractMatrix)),
259-
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedMatrix)),
260-
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedMatrix)),
261-
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedMatrix)),
289+
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::AbstractMatrix{<:TrackedReal})),
290+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractMatrix)),
291+
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix{<:TrackedReal})),
292+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractMatrix{<:TrackedReal})),
293+
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractMatrix{<:TrackedReal})),
262294
]
263295
@eval $func_header = track(simplex_logpdf, alpha, lmnB, x)
264296
end
265297
@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
266-
simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin
267-
(Δ .* log.(value(x)), -Δ, Δ .* (value(alpha) .- 1))
298+
_alpha = value(alpha)
299+
_lmnB = value(lmnB)
300+
_x = value(x)
301+
simplex_logpdf(_alpha, _lmnB, _x), Δ -> begin
302+
(Δ .* log.(_x), -Δ, Δ .* (_alpha .- 1) ./ _x)
268303
end
269304
end
270305
@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
271-
simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin
272-
(log.(value(x)) * Δ, -sum(Δ), repeat(value(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ))
306+
_alpha = value(alpha)
307+
_lmnB = value(lmnB)
308+
_x = value(x)
309+
simplex_logpdf(_alpha, _lmnB, _x), Δ -> begin
310+
(log.(_x) * Δ, -sum(Δ), ((_alpha .- 1) ./ _x) * Diagonal(Δ))
273311
end
274312
end
275313

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Combinatorics = "1.0.2"
2121
Distributions = "0.24.3"
2222
FiniteDifferences = "0.11.3, 0.12"
2323
ForwardDiff = "0.10.12"
24-
NNlib = "0.7.7"
24+
NNlib = "0.7.10"
2525
PDMats = "0.10.1"
2626
ReverseDiff = "1.4.4"
2727
StatsBase = "0.33.2"

test/ad/distributions.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@
2626
to_positive(x) = exp.(x)
2727
to_positive(x::AbstractArray{<:AbstractArray}) = to_positive.(x)
2828

29-
# Create vectors in probability simplex.
30-
to_simplex(x::AbstractArray; dims=1) = NNlib.softmax(x; dims=dims)
31-
to_simplex(x::AbstractArray{<:AbstractArray}; dims=1) = to_simplex.(x; dims=dims)
32-
3329
# Tests that have a `broken` field can be executed but, according to FiniteDifferences,
3430
# fail to produce the correct result. These tests can be checked with `@test_broken`.
3531
univariate_distributions = DistSpec[

test/runtests.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,32 @@ if GROUP == "All" || GROUP == "AD"
4545
to_posdef(A::AbstractMatrix) = A * A' + I
4646
to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1)
4747

48+
# Create vectors in probability simplex.
49+
to_simplex(x::AbstractArray) = NNlib.softmax(x; dims=1)
50+
to_simplex(x::AbstractArray{<:AbstractArray}) = to_simplex.(x)
51+
52+
if AD == "All" || AD == "ReverseDiff"
53+
@eval begin
54+
# Define adjoint for ReverseDiff
55+
function to_simplex(x::AbstractArray{<:ReverseDiff.TrackedReal})
56+
return ReverseDiff.track(to_simplex, x)
57+
end
58+
ReverseDiff.@grad function to_simplex(x)
59+
_x = ReverseDiff.value(x)
60+
y = to_simplex(_x)
61+
function pullback(∇)
62+
return (NNlib.∇softmax(∇, _x, y; dims=1),)
63+
end
64+
return y, pullback
65+
end
66+
end
67+
end
68+
4869
if AD == "All" || AD == "Tracker"
4970
@eval begin
5071
# Define adjoints for Tracker
51-
to_posdef(A::TrackedMatrix) = Tracker.track(to_posdef, A)
52-
Tracker.@grad function to_posdef(A::TrackedMatrix)
72+
to_posdef(A::Tracker.TrackedMatrix) = Tracker.track(to_posdef, A)
73+
Tracker.@grad function to_posdef(A::Tracker.TrackedMatrix)
5374
data_A = Tracker.data(A)
5475
S = data_A * data_A' + I
5576
function pullback(∇)

0 commit comments

Comments
 (0)