Skip to content

Commit 8713016

Browse files
authored
Merge pull request #25 from TuringLang/mt/zygote_multi_and_matrix
Zygote support
2 parents 9267b79 + c899e16 commit 8713016

File tree

11 files changed

+607
-181
lines changed

11 files changed

+607
-181
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ julia:
88
- 1.0
99
- 1.1
1010
- 1.2
11+
- 1.3
1112
- nightly
1213

1314
matrix:

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1213
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1314
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1415
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
16+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1517

1618
[compat]
1719
Combinatorics = "0.7"
@@ -20,12 +22,12 @@ ForwardDiff = "0.10.6"
2022
PDMats = "0.9"
2123
StatsFuns = "0.8, 0.9"
2224
Tracker = "0.2.5"
23-
Zygote = "0.4.1"
25+
Zygote = "0.4.7"
2426
julia = "1"
2527

2628
[extras]
2729
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2830
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2931

3032
[targets]
31-
test = ["Test", "FiniteDifferences"]
33+
test = ["Test", "FiniteDifferences"]

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
[![Coverage Status](https://coveralls.io/repos/github/TuringLang/DistributionsAD.jl/badge.svg?branch=master)](https://coveralls.io/github/TuringLang/DistributionsAD.jl?branch=master)
66

77

8-
This package defines the necessary functions to enable automatic differentiation (AD) of the `logpdf` function from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) using the packages [Tracker.jl](https://github.com/FluxML/Tracker.jl) and [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). The goal of this package is to make the output of `logpdf` differentiable wrt all continuous parameters of a distribution as well as the random variable in the case of continuous distributions.
8+
This package defines the necessary functions to enable automatic differentiation (AD) of the `logpdf` function from [Distributions.jl](https://github.com/JuliaStats/Distributions.jl) using the packages [Tracker.jl](https://github.com/FluxML/Tracker.jl), [Zygote.jl](https://github.com/FluxML/Zygote.jl) and [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). The goal of this package is to make the output of `logpdf` differentiable wrt all continuous parameters of a distribution as well as the random variable in the case of continuous distributions.
99

1010
AD of `logpdf` is fully supported and tested for the following distributions wrt all combinations of continuous variables (distribution parameters and/or the random variable) and using all defined distribution constructors:
1111
- Univariate discrete
@@ -24,7 +24,6 @@ AD of `logpdf` is fully supported and tested for the following distributions wrt
2424
- `BetaPrime`
2525
- `Biweight`
2626
- `Cauchy`
27-
- `Chernoff`
2827
- `Chi`
2928
- `Chisq`
3029
- `Cosine`
@@ -64,6 +63,8 @@ AD of `logpdf` is fully supported and tested for the following distributions wrt
6463
- `MvNormal`
6564
- Matrix-variate continuous
6665
- `MatrixBeta`
66+
- `Wishart`
67+
- `InverseWishart`
6768

6869
# Get Involved
6970

src/DistributionsAD.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ module DistributionsAD
33
using PDMats,
44
ForwardDiff,
55
Zygote,
6-
Tracker,
76
LinearAlgebra,
87
Distributions,
98
Random,
10-
Combinatorics
9+
Combinatorics,
10+
SpecialFunctions,
11+
StatsFuns
1112

12-
using Tracker: TrackedReal
13+
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
14+
TrackedVecOrMat, track, data
15+
using ZygoteRules: ZygoteRules, pullback
1316
using LinearAlgebra: copytri!
1417
using Distributions: AbstractMvLogNormal,
1518
ContinuousMultivariateDistribution
@@ -30,10 +33,13 @@ export TuringScalMvNormal,
3033
TuringDiagMvNormal,
3134
TuringDenseMvNormal,
3235
TuringMvLogNormal,
33-
TuringPoissonBinomial
36+
TuringPoissonBinomial,
37+
TuringWishart,
38+
TuringInverseWishart
3439

3540
include("common.jl")
3641
include("univariate.jl")
3742
include("multivariate.jl")
43+
include("matrixvariate.jl")
3844

3945
end

src/common.jl

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,65 @@
11
## Generic ##
22

33
function Base.fill(
4-
value::Tracker.TrackedReal,
4+
value::TrackedReal,
55
dims::Vararg{Union{Integer, AbstractUnitRange}},
66
)
7-
return Tracker.track(fill, value, dims...)
7+
return track(fill, value, dims...)
88
end
99
Tracker.@grad function Base.fill(value::Real, dims...)
10-
return fill(Tracker.data(value), dims...), function(Δ)
10+
return fill(data(value), dims...), function(Δ)
1111
size(Δ) dims && error("Dimension mismatch")
1212
return (sum(Δ), map(_->nothing, dims)...)
1313
end
1414
end
1515

1616
## StatsFuns ##
1717

18-
logsumexp(x::Tracker.TrackedArray) = Tracker.track(logsumexp, x)
19-
Tracker.@grad function logsumexp(x::Tracker.TrackedArray)
20-
lse = logsumexp(Tracker.data(x))
21-
return lse,
22-
Δ->.* exp.(x .- lse),)
18+
logsumexp(x::TrackedArray) = track(logsumexp, x)
19+
Tracker.@grad function logsumexp(x::TrackedArray)
20+
lse = logsumexp(data(x))
21+
return lse, Δ ->.* exp.(x .- lse),)
2322
end
2423

2524
## Linear algebra ##
2625

27-
LinearAlgebra.UpperTriangular(A::Tracker.TrackedMatrix) = Tracker.track(UpperTriangular, A)
26+
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
2827
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
29-
return UpperTriangular(Tracker.data(A)), Δ->(UpperTriangular(Δ),)
28+
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
3029
end
3130

32-
function LinearAlgebra.cholesky(A::Tracker.TrackedMatrix; check=true)
31+
function LinearAlgebra.cholesky(A::TrackedMatrix; check=true)
3332
factors_info = turing_chol(A, check)
3433
factors = factors_info[1]
35-
info = Tracker.data(factors_info[2])
34+
info = data(factors_info[2])
3635
return Cholesky{eltype(factors), typeof(factors)}(factors, 'U', info)
3736
end
3837
function turing_chol(A::AbstractMatrix, check)
3938
chol = cholesky(A, check=check)
4039
(chol.factors, chol.info)
4140
end
42-
turing_chol(A::Tracker.TrackedMatrix, check) = Tracker.track(turing_chol, A, check)
41+
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
4342
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
44-
C, back = Zygote.pullback(unsafe_cholesky, Tracker.data(A), Tracker.data(check))
45-
return (C.factors, C.info), Δ->back((factors=Tracker.data(Δ[1]),))
43+
C, back = pullback(unsafe_cholesky, data(A), data(check))
44+
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
4645
end
4746

4847
unsafe_cholesky(x, check) = cholesky(x, check=check)
49-
Zygote.@adjoint function unsafe_cholesky::Real, check)
48+
ZygoteRules.@adjoint function unsafe_cholesky::Real, check)
5049
C = cholesky(Σ; check=check)
5150
return C, function::NamedTuple)
5251
issuccess(C) || return (zero(Σ), nothing)
5352
.factors[1, 1] / (2 * C.U[1, 1]), nothing)
5453
end
5554
end
56-
Zygote.@adjoint function unsafe_cholesky::Diagonal, check)
55+
ZygoteRules.@adjoint function unsafe_cholesky::Diagonal, check)
5756
C = cholesky(Σ; check=check)
5857
return C, function::NamedTuple)
5958
issuccess(C) || (Diagonal(zero(diag.factors))), nothing)
6059
(Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing)
6160
end
6261
end
63-
Zygote.@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
62+
ZygoteRules.@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
6463
C = cholesky(Σ; check=check)
6564
return C, function::NamedTuple)
6665
issuccess(C) || return (zero.factors), nothing)
@@ -75,38 +74,33 @@ Zygote.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Re
7574
return (UpperTriangular(Σ̄), nothing)
7675
end
7776
end
78-
77+
7978
# Specialised logdet for cholesky to target the triangle directly.
8079
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
81-
logdet_chol_tri(U::Tracker.TrackedMatrix) = Tracker.track(logdet_chol_tri, U)
80+
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
8281
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
83-
U_data = Tracker.data(U)
82+
U_data = data(U)
8483
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
8584
end
8685

87-
function LinearAlgebra.logdet(C::Cholesky{<:Tracker.TrackedReal, <:Tracker.TrackedMatrix})
86+
function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
8887
return logdet_chol_tri(C.U)
8988
end
9089

9190
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
9291
zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
93-
function zygote_ldiv(A::Tracker.TrackedMatrix, B::Tracker.TrackedVecOrMat)
94-
return Tracker.track(zygote_ldiv, A, B)
92+
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
93+
return track(zygote_ldiv, A, B)
9594
end
96-
function zygote_ldiv(A::Tracker.TrackedMatrix, B::AbstractVecOrMat)
97-
return Tracker.track(zygote_ldiv, A, B)
95+
function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
96+
return track(zygote_ldiv, A, B)
9897
end
99-
zygote_ldiv(A::AbstractMatrix, B::Tracker.TrackedVecOrMat) = Tracker.track(zygote_ldiv, A, B)
98+
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
10099
Tracker.@grad function zygote_ldiv(A, B)
101-
Y, back = Zygote.pullback(\, Tracker.data(A), Tracker.data(B))
102-
return Y, Δ->back(Tracker.data(Δ))
100+
Y, back = pullback(\, data(A), data(B))
101+
return Y, Δ->back(data(Δ))
103102
end
104103

105-
function Base.:\(a::Cholesky{<:Tracker.TrackedReal, <:Tracker.TrackedArray}, b::AbstractVecOrMat)
104+
function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
106105
return (a.U \ (a.U' \ b))
107106
end
108-
109-
## PDMats ##
110-
111-
PDMats.invquad::PDiagMat, x::Tracker.TrackedVector) = sum(abs2.(x) ./ Σ.diag)
112-
PDMats.invquad::PDMat, x::Tracker.TrackedVector) = sum(abs2, zygote_ldiv.chol.L, x))

0 commit comments

Comments
 (0)