Skip to content

Commit fe3f010

Browse files
committed
rebase issues
1 parent 4b458df commit fe3f010

File tree

5 files changed

+65
-71
lines changed

5 files changed

+65
-71
lines changed

src/DistributionsAD.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@ module DistributionsAD
33
using PDMats,
44
ForwardDiff,
55
Zygote,
6-
ZygoteRules,
7-
Tracker,
86
LinearAlgebra,
97
Distributions,
108
Random,
119
Combinatorics,
1210
SpecialFunctions,
1311
StatsFuns
1412

15-
using Tracker: TrackedReal
13+
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
14+
TrackedVecOrMat, track, data
15+
using ZygoteRules: ZygoteRules, pullback
1616
using LinearAlgebra: copytri!
1717
using Distributions: AbstractMvLogNormal,
1818
ContinuousMultivariateDistribution

src/common.jl

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,47 @@
11
## Generic ##
22

33
function Base.fill(
4-
value::Tracker.TrackedReal,
4+
value::TrackedReal,
55
dims::Vararg{Union{Integer, AbstractUnitRange}},
66
)
7-
return Tracker.track(fill, value, dims...)
7+
return track(fill, value, dims...)
88
end
99
Tracker.@grad function Base.fill(value::Real, dims...)
10-
return fill(Tracker.data(value), dims...), function(Δ)
10+
return fill(data(value), dims...), function(Δ)
1111
size(Δ) dims && error("Dimension mismatch")
1212
return (sum(Δ), map(_->nothing, dims)...)
1313
end
1414
end
1515

1616
## StatsFuns ##
1717

18-
logsumexp(x::Tracker.TrackedArray) = Tracker.track(logsumexp, x)
19-
Tracker.@grad function logsumexp(x::Tracker.TrackedArray)
20-
lse = logsumexp(Tracker.data(x))
21-
return lse,
22-
Δ->.* exp.(x .- lse),)
18+
logsumexp(x::TrackedArray) = track(logsumexp, x)
19+
Tracker.@grad function logsumexp(x::TrackedArray)
20+
lse = logsumexp(data(x))
21+
return lse, Δ ->.* exp.(x .- lse),)
2322
end
2423

2524
## Linear algebra ##
2625

27-
LinearAlgebra.UpperTriangular(A::Tracker.TrackedMatrix) = Tracker.track(UpperTriangular, A)
26+
LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
2827
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
29-
return UpperTriangular(Tracker.data(A)), Δ->(UpperTriangular(Δ),)
28+
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
3029
end
3130

32-
function LinearAlgebra.cholesky(A::Tracker.TrackedMatrix; check=true)
31+
function LinearAlgebra.cholesky(A::TrackedMatrix; check=true)
3332
factors_info = turing_chol(A, check)
3433
factors = factors_info[1]
35-
info = Tracker.data(factors_info[2])
34+
info = data(factors_info[2])
3635
return Cholesky{eltype(factors), typeof(factors)}(factors, 'U', info)
3736
end
3837
function turing_chol(A::AbstractMatrix, check)
3938
chol = cholesky(A, check=check)
4039
(chol.factors, chol.info)
4140
end
42-
turing_chol(A::Tracker.TrackedMatrix, check) = Tracker.track(turing_chol, A, check)
41+
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
4342
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
44-
C, back = ZygoteRules.pullback(unsafe_cholesky, Tracker.data(A), Tracker.data(check))
45-
return (C.factors, C.info), Δ->back((factors=Tracker.data(Δ[1]),))
43+
C, back = pullback(unsafe_cholesky, data(A), data(check))
44+
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
4645
end
4746

4847
unsafe_cholesky(x, check) = cholesky(x, check=check)
@@ -75,38 +74,33 @@ ZygoteRules.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric
7574
return (UpperTriangular(Σ̄), nothing)
7675
end
7776
end
78-
77+
7978
# Specialised logdet for cholesky to target the triangle directly.
8079
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
81-
logdet_chol_tri(U::Tracker.TrackedMatrix) = Tracker.track(logdet_chol_tri, U)
80+
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
8281
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
83-
U_data = Tracker.data(U)
82+
U_data = data(U)
8483
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
8584
end
8685

87-
function LinearAlgebra.logdet(C::Cholesky{<:Tracker.TrackedReal, <:Tracker.TrackedMatrix})
86+
function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
8887
return logdet_chol_tri(C.U)
8988
end
9089

9190
# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.
9291
zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
93-
function zygote_ldiv(A::Tracker.TrackedMatrix, B::Tracker.TrackedVecOrMat)
94-
return Tracker.track(zygote_ldiv, A, B)
92+
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
93+
return track(zygote_ldiv, A, B)
9594
end
96-
function zygote_ldiv(A::Tracker.TrackedMatrix, B::AbstractVecOrMat)
97-
return Tracker.track(zygote_ldiv, A, B)
95+
function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
96+
return track(zygote_ldiv, A, B)
9897
end
99-
zygote_ldiv(A::AbstractMatrix, B::Tracker.TrackedVecOrMat) = Tracker.track(zygote_ldiv, A, B)
98+
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
10099
Tracker.@grad function zygote_ldiv(A, B)
101-
Y, back = ZygoteRules.pullback(\, Tracker.data(A), Tracker.data(B))
102-
return Y, Δ->back(Tracker.data(Δ))
100+
Y, back = pullback(\, data(A), data(B))
101+
return Y, Δ->back(data(Δ))
103102
end
104103

105-
function Base.:\(a::Cholesky{<:Tracker.TrackedReal, <:Tracker.TrackedArray}, b::AbstractVecOrMat)
104+
function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
106105
return (a.U \ (a.U' \ b))
107106
end
108-
109-
## PDMats ##
110-
111-
PDMats.invquad::PDiagMat, x::Tracker.TrackedVector) = sum(abs2.(x) ./ Σ.diag)
112-
PDMats.invquad::PDMat, x::Tracker.TrackedVector) = sum(abs2, zygote_ldiv.chol.L, x))

src/matrixvariate.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::AbstractMatri
200200
## Adjoints
201201

202202
ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
203-
return ZygoteRules.pullback(TuringWishart, df, S)
203+
return pullback(TuringWishart, df, S)
204204
end
205205
ZygoteRules.@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
206-
return ZygoteRules.pullback(TuringInverseWishart, df, S)
206+
return pullback(TuringInverseWishart, df, S)
207207
end
208208

209209
Distributions.Wishart(df::TrackedReal, S::Matrix{<:Real}) = TuringWishart(df, S)

src/multivariate.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ end
167167
for T in (:AbstractVector, :AbstractMatrix)
168168
@eval Distributions.logpdf(d::TuringMvLogNormal, x::$T) = _logpdf(d, x)
169169
end
170-
for T in (:(Tracker.TrackedVector), :(Tracker.TrackedMatrix))
170+
for T in (:TrackedVector, :TrackedMatrix)
171171
@eval Distributions.logpdf(d::TuringMvLogNormal, x::$T) = _logpdf(d, x)
172172
end
173173
function _logpdf(d::TuringMvLogNormal, x::AbstractVector{T}) where {T<:Real}
@@ -248,18 +248,18 @@ MvLogNormal(d::Int, σ::TrackedReal{<:Real}) = TuringMvLogNormal(TuringMvNormal(
248248
ZygoteRules.@adjoint function Distributions.MvNormal(
249249
A::Union{AbstractVector{<:Real}, AbstractMatrix{<:Real}},
250250
)
251-
return ZygoteRules.pullback(TuringMvNormal, A)
251+
return pullback(TuringMvNormal, A)
252252
end
253253
ZygoteRules.@adjoint function Distributions.MvNormal(
254254
m::AbstractVector{<:Real},
255255
A::Union{Real, UniformScaling, AbstractVecOrMat{<:Real}},
256256
)
257-
return ZygoteRules.pullback(TuringMvNormal, m, A)
257+
return pullback(TuringMvNormal, m, A)
258258
end
259259
ZygoteRules.@adjoint function Distributions.MvNormal(
260260
d::Int,
261261
A::Real,
262262
)
263-
value, back = ZygoteRules.pullback(A -> TuringMvNormal(d, A), A)
263+
value, back = pullback(A -> TuringMvNormal(d, A), A)
264264
return value, x -> (nothing, back(x)[1])
265265
end

src/univariate.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ function uniformlogpdf(a, b, x)
2727
return oftype(c, -Inf)
2828
end
2929
end
30-
uniformlogpdf(a::Real, b::Real, x::TrackedReal) = Tracker.track(uniformlogpdf, a, b, x)
31-
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = Tracker.track(uniformlogpdf, a, b, x)
32-
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = Tracker.track(uniformlogpdf, a, b, x)
30+
uniformlogpdf(a::Real, b::Real, x::TrackedReal) = track(uniformlogpdf, a, b, x)
31+
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = track(uniformlogpdf, a, b, x)
32+
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlogpdf, a, b, x)
3333
Tracker.@grad function uniformlogpdf(a, b, x)
34-
diff = Tracker.data(b) - Tracker.data(a)
34+
diff = data(b) - data(a)
3535
T = typeof(diff)
3636
l = -log(diff)
3737
f = isfinite(l)
@@ -49,7 +49,7 @@ ZygoteRules.@adjoint function uniformlogpdf(a, b, x)
4949
return l, Δ->(f ? da : n, f ? -da : n, f ? zero(T) : n)
5050
end
5151
ZygoteRules.@adjoint function Distributions.Uniform(args...)
52-
return ZygoteRules.pullback(TuringUniform, args...)
52+
return pullback(TuringUniform, args...)
5353
end
5454

5555
## Beta ##
@@ -123,12 +123,12 @@ end
123123
logpdf(d::Semicircle{<:Real}, x::TrackedReal) = semicirclelogpdf(d.r, x)
124124
logpdf(d::Semicircle{<:TrackedReal}, x::Real) = semicirclelogpdf(d.r, x)
125125
logpdf(d::Semicircle{<:TrackedReal}, x::TrackedReal) = semicirclelogpdf(d.r, x)
126-
semicirclelogpdf(r::TrackedReal, x::Real) = Tracker.track(semicirclelogpdf, r, x)
127-
semicirclelogpdf(r::Real, x::TrackedReal) = Tracker.track(semicirclelogpdf, r, x)
128-
semicirclelogpdf(r::TrackedReal, x::TrackedReal) = Tracker.track(semicirclelogpdf, r, x)
126+
semicirclelogpdf(r::TrackedReal, x::Real) = track(semicirclelogpdf, r, x)
127+
semicirclelogpdf(r::Real, x::TrackedReal) = track(semicirclelogpdf, r, x)
128+
semicirclelogpdf(r::TrackedReal, x::TrackedReal) = track(semicirclelogpdf, r, x)
129129
Tracker.@grad function semicirclelogpdf(r, x)
130-
rd = Tracker.data(r)
131-
xd = Tracker.data(x)
130+
rd = data(r)
131+
xd = data(x)
132132
xx, rr = promote(xd, float(rd))
133133
d = Semicircle(rr)
134134
T = typeof(xx)
@@ -146,9 +146,9 @@ end
146146

147147
## Binomial ##
148148

149-
binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int) = Tracker.track(binomlogpdf, n, p, x)
150-
Tracker.@grad function binomlogpdf(n::Int, p::Tracker.TrackedReal, x::Int)
151-
return binomlogpdf(n, Tracker.data(p), x),
149+
binomlogpdf(n::Int, p::TrackedReal, x::Int) = track(binomlogpdf, n, p, x)
150+
Tracker.@grad function binomlogpdf(n::Int, p::TrackedReal, x::Int)
151+
return binomlogpdf(n, data(p), x),
152152
Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing)
153153
end
154154
ZygoteRules.@adjoint function binomlogpdf(n::Int, p::Real, x::Int)
@@ -171,19 +171,19 @@ end
171171
_nbinomlogpdf_grad_1(r, p, k) = k == 0 ? log(p) : sum(1 / (k + r - i) for i in 1:k) + log(p)
172172
_nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p
173173

174-
nbinomlogpdf(n::Tracker.TrackedReal, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
175-
nbinomlogpdf(n::Real, p::Tracker.TrackedReal, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
176-
nbinomlogpdf(n::Tracker.TrackedReal, p::Real, x::Int) = Tracker.track(nbinomlogpdf, n, p, x)
177-
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Tracker.TrackedReal, k::Int)
178-
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
174+
nbinomlogpdf(n::TrackedReal, p::TrackedReal, x::Int) = track(nbinomlogpdf, n, p, x)
175+
nbinomlogpdf(n::Real, p::TrackedReal, x::Int) = track(nbinomlogpdf, n, p, x)
176+
nbinomlogpdf(n::TrackedReal, p::Real, x::Int) = track(nbinomlogpdf, n, p, x)
177+
Tracker.@grad function nbinomlogpdf(r::TrackedReal, p::TrackedReal, k::Int)
178+
return nbinomlogpdf(data(r), data(p), k),
179179
Δ->* _nbinomlogpdf_grad_1(r, p, k), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing)
180180
end
181-
Tracker.@grad function nbinomlogpdf(r::Real, p::Tracker.TrackedReal, k::Int)
182-
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
181+
Tracker.@grad function nbinomlogpdf(r::Real, p::TrackedReal, k::Int)
182+
return nbinomlogpdf(data(r), data(p), k),
183183
Δ->(Tracker._zero(r), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing)
184184
end
185-
Tracker.@grad function nbinomlogpdf(r::Tracker.TrackedReal, p::Real, k::Int)
186-
return nbinomlogpdf(Tracker.data(r), Tracker.data(p), k),
185+
Tracker.@grad function nbinomlogpdf(r::TrackedReal, p::Real, k::Int)
186+
return nbinomlogpdf(data(r), data(p), k),
187187
Δ->* _nbinomlogpdf_grad_1(r, p, k), Tracker._zero(p), nothing)
188188
end
189189

@@ -212,9 +212,9 @@ end
212212

213213
## Poisson ##
214214

215-
poislogpdf(v::Tracker.TrackedReal, x::Int) = Tracker.track(poislogpdf, v, x)
216-
Tracker.@grad function poislogpdf(v::Tracker.TrackedReal, x::Int)
217-
return poislogpdf(Tracker.data(v), x),
215+
poislogpdf(v::TrackedReal, x::Int) = track(poislogpdf, v, x)
216+
Tracker.@grad function poislogpdf(v::TrackedReal, x::Int)
217+
return poislogpdf(data(v), x),
218218
Δ->* (x/v - 1), nothing)
219219
end
220220
ZygoteRules.@adjoint function poislogpdf(v::Real, x::Int)
@@ -244,13 +244,13 @@ function logpdf(d::TuringPoissonBinomial{T}, k::Int) where T<:Real
244244
insupport(d, k) ? log(d.pmf[k + 1]) : -T(Inf)
245245
end
246246
quantile(d::TuringPoissonBinomial, x::Float64) = quantile(Categorical(d.pmf), x) - 1
247-
PoissonBinomial(p::Tracker.TrackedArray) = TuringPoissonBinomial(p)
247+
PoissonBinomial(p::TrackedArray) = TuringPoissonBinomial(p)
248248
Base.minimum(d::TuringPoissonBinomial) = 0
249249
Base.maximum(d::TuringPoissonBinomial) = length(d.p)
250250

251-
poissonbinomial_pdf_fft(x::Tracker.TrackedArray) = Tracker.track(poissonbinomial_pdf_fft, x)
252-
Tracker.@grad function poissonbinomial_pdf_fft(x::Tracker.TrackedArray)
253-
x_data = Tracker.data(x)
251+
poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x)
252+
Tracker.@grad function poissonbinomial_pdf_fft(x::TrackedArray)
253+
x_data = data(x)
254254
T = eltype(x_data)
255255
fft = poissonbinomial_pdf_fft(x_data)
256256
return fft, Δ -> begin
@@ -268,7 +268,7 @@ end
268268
# The code below doesn't work because of bugs in Zygote. The above is inefficient.
269269
#=
270270
ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real})
271-
return ZygoteRules.pullback(poissonbinomial_pdf_fft_zygote, x)
271+
return pullback(poissonbinomial_pdf_fft_zygote, x)
272272
end
273273
function poissonbinomial_pdf_fft_zygote(p::AbstractArray{T}) where {T <: Real}
274274
n = length(p)

0 commit comments

Comments
 (0)