Skip to content

Commit c463960

Browse files
authored
Update ChainRules definitions and add differential for PoissonBinomial pdf (#162)
1 parent 978b1fe commit c463960

File tree

11 files changed

+213
-125
lines changed

11 files changed

+213
-125
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.22"
3+
version = "0.6.23"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -10,7 +10,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1010
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1111
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1212
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
13-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1413
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1514
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1615
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
@@ -30,7 +29,6 @@ Compat = "3.6"
3029
DiffRules = "0.1, 1.0"
3130
Distributions = "0.23.3, 0.24"
3231
FillArrays = "0.8, 0.9, 0.10, 0.11"
33-
ForwardDiff = "0.10.6"
3432
NaNMath = "0.3"
3533
PDMats = "0.9, 0.10, 0.11"
3634
Requires = "1"

src/DistributionsAD.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import StatsFuns: logsumexp,
2828
nbetalogpdf
2929
import Distributions: MvNormal,
3030
MvLogNormal,
31-
poissonbinomial_pdf_fft,
3231
logpdf,
3332
quantile,
3433
PoissonBinomial,
@@ -65,9 +64,6 @@ include("zygote.jl")
6564
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
6665
using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
6766
include("forwarddiff.jl")
68-
69-
# loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft`
70-
include("zygote_forwarddiff.jl")
7167
end
7268

7369
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin

src/chainrules.jl

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,75 +11,131 @@
1111
(c, -c, z),
1212
)
1313

14+
# StatsFuns: https://github.com/JuliaStats/StatsFuns.jl/pull/106
15+
1416
## Beta ##
1517

1618
@scalar_rule(
1719
betalogpdf::Real, β::Real, x::Number),
18-
@setup(di = digamma+ β)),
20+
@setup(z = digamma+ β)),
1921
(
20-
@thunk(log(x) - digamma) + di),
21-
@thunk(log(1 - x) - digamma) + di),
22-
@thunk((α - 1)/x + (1 - β)/(1 - x)),
22+
log(x) + z - digamma(α),
23+
log1p(-x) + z - digamma(β),
24+
(α - 1) / x + (1 - β) / (1 - x),
2325
),
2426
)
2527

2628
## Gamma ##
2729

2830
@scalar_rule(
2931
gammalogpdf(k::Real, θ::Real, x::Number),
32+
@setup(
33+
invθ = inv(θ),
34+
xoθ = invθ * x,
35+
z = xoθ - k,
36+
),
3037
(
31-
@thunk(-digamma(k) - log(θ) + log(x)),
32-
@thunk(-k/θ + x/θ^2),
33-
@thunk((k - 1)/x - 1/θ),
38+
log(xoθ) - digamma(k),
39+
invθ * z,
40+
- (1 + z) / x,
3441
),
3542
)
3643

3744
## Chisq ##
3845

3946
@scalar_rule(
4047
chisqlogpdf(k::Real, x::Number),
41-
@setup(ko2 = k / 2),
42-
(@thunk((-logtwo - digamma(ko2) + log(x)) / 2), @thunk((ko2 - 1)/x - one(ko2) / 2)),
48+
@setup(hk = k / 2),
49+
(
50+
(log(x) - logtwo - digamma(hk)) / 2,
51+
(hk - 1) / x - one(hk) / 2,
52+
),
4353
)
4454

4555
## FDist ##
4656

4757
@scalar_rule(
48-
fdistlogpdf(v1::Real, v2::Real, x::Number),
58+
fdistlogpdf(ν1::Real, ν2::Real, x::Number),
4959
@setup(
50-
temp1 = v1 * x + v2,
51-
temp2 = log(temp1),
52-
vsum = v1 + v2,
53-
temp3 = vsum / temp1,
54-
temp4 = digamma(vsum / 2),
60+
xν1 = x * ν1,
61+
temp1 = xν1 + ν2,
62+
a = (x - 1) / temp1,
63+
νsum = ν1 + ν2,
64+
di = digamma(νsum / 2),
5565
),
5666
(
57-
@thunk((log(v1 * x) + 1 - temp2 - x * temp3 - digamma(v1 / 2) + temp4) / 2),
58-
@thunk((log(v2) + 1 - temp2 - temp3 - digamma(v2 / 2) + temp4) / 2),
59-
@thunk(v1 / 2 * (1 / x - temp3) - 1 / x),
67+
(-log1p(ν2 / xν1) - ν2 * a + di - digamma(ν1 / 2)) / 2,
68+
(-log1p(xν1 / ν2) + ν1 * a + di - digamma(ν2 / 2)) / 2,
69+
((ν1 - 2) / x - ν1 * νsum / temp1) / 2,
6070
),
6171
)
6272

6373
## TDist ##
6474

6575
@scalar_rule(
66-
tdistlogpdf(v::Real, x::Number),
76+
tdistlogpdf::Real, x::Number),
77+
@setup(
78+
νp1 = ν + 1,
79+
xsq = x^2,
80+
invν = inv(ν),
81+
a = xsq * invν,
82+
b = νp1 /+ xsq),
83+
),
6784
(
68-
@thunk((digamma((v + 1) / 2) - 1 / v - digamma(v / 2) - log(1 + x^2 / v) + x^2 * (v + 1) / v^2 / (1 + x^2 / v)) / 2),
69-
@thunk(-x * (v + 1) / (v + x^2)),
70-
)
85+
(digamma(νp1 / 2) - digamma(ν / 2) + a * b - log1p(a) - invν) / 2,
86+
- x * b,
87+
),
7188
)
7289

7390
## Binomial ##
7491

7592
@scalar_rule(
76-
binomlogpdf(n::Int, p::Real, x::Int),
77-
(DoesNotExist(), x / p - (n - x) / (1 - p), DoesNotExist()),
93+
binomlogpdf(n::Real, p::Real, k::Real),
94+
@setup(z = digamma(n - k + 1)),
95+
(
96+
digamma(n + 2) - z + log1p(-p) - 1 / (1 + n),
97+
(k / p - n) / (1 - p),
98+
z - digamma(k + 1) + logit(p),
99+
),
78100
)
79101

80102
## Poisson ##
81103

82104
@scalar_rule(
83-
poislogpdf(v::Real, x::Int),
84-
(x / v - 1, DoesNotExist()),
105+
poislogpdf::Number, x::Number),
106+
((iszero(x) && iszero(λ) ? zero(x / λ) : x / λ) - 1, log(λ) - digamma(x + 1)),
107+
)
108+
109+
## PoissonBinomial
110+
111+
function ChainRulesCore.rrule(
112+
::typeof(Distributions.poissonbinomial_pdf_fft), p::AbstractVector{<:Real}
85113
)
114+
y = Distributions.poissonbinomial_pdf_fft(p)
115+
A = poissonbinomial_partialderivatives(p)
116+
function poissonbinomial_pdf_fft_pullback(Δy)
117+
= InplaceableThunk(
118+
@thunk(A * Δy),
119+
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
120+
)
121+
return (NO_FIELDS, p̄)
122+
end
123+
return y, poissonbinomial_pdf_fft_pullback
124+
end
125+
126+
if isdefined(Distributions, :poissonbinomial_pdf)
127+
function ChainRulesCore.rrule(
128+
::typeof(Distributions.poissonbinomial_pdf), p::AbstractVector{<:Real}
129+
)
130+
y = Distributions.poissonbinomial_pdf(p)
131+
A = poissonbinomial_partialderivatives(p)
132+
function poissonbinomial_pdf_pullback(Δy)
133+
= InplaceableThunk(
134+
@thunk(A * Δy),
135+
Δ -> LinearAlgebra.mul!(Δ, A, Δy, true, true),
136+
)
137+
return (NO_FIELDS, p̄)
138+
end
139+
return y, poissonbinomial_pdf_pullback
140+
end
141+
end

src/common.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,45 @@ parameterless_type(x) = parameterless_type(typeof(x))
4646
parameterless_type(x::Type) = __parameterless_type(x)
4747

4848
@non_differentiable adapt_randn(::Any...)
49+
50+
# PoissonBinomial
51+
52+
# compute matrix of partial derivatives [∂P(X=j-1)/∂pᵢ]_{i=1,…,n; j=1,…,n+1}
53+
#
54+
# This uses the same dynamic programming "trick" as for the computation of the primals
55+
# in Distributions
56+
#
57+
# Reference (for the primal):
58+
#
59+
# Marlin A. Thomas & Audrey E. Taub (1982)
60+
# Calculating binomial probabilities when the trial probabilities are unequal,
61+
# Journal of Statistical Computation and Simulation, 14:2, 125-131, DOI: 10.1080/00949658208810534
62+
function poissonbinomial_partialderivatives(p)
63+
n = length(p)
64+
A = zeros(eltype(p), n, n + 1)
65+
@inbounds for j in 1:n
66+
A[j, end] = 1
67+
end
68+
@inbounds for (i, pi) in enumerate(p)
69+
qi = 1 - pi
70+
for k in (n - i + 1):n
71+
kp1 = k + 1
72+
for j in 1:(i - 1)
73+
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
74+
end
75+
for j in (i+1):n
76+
A[j, k] = pi * A[j, k] + qi * A[j, kp1]
77+
end
78+
end
79+
for j in 1:(i-1)
80+
A[j, end] *= pi
81+
end
82+
for j in (i+1):n
83+
A[j, end] *= pi
84+
end
85+
end
86+
@inbounds for j in 1:n, i in 1:n
87+
A[i, j] -= A[i, j+1]
88+
end
89+
return A
90+
end

src/tracker.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -261,26 +261,22 @@ end
261261
PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) =
262262
TuringPoissonBinomial(p; check_args = check_args)
263263

264-
# TODO: add adjoints without ForwardDiff
265264
poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x)
266265
@grad function poissonbinomial_pdf_fft(x::TrackedArray)
267266
x_data = data(x)
268-
T = eltype(x_data)
269-
fft = poissonbinomial_pdf_fft(x_data)
270-
return fft, Δ -> begin
271-
((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x_data)::Matrix{T})' * Δ,)
272-
end
267+
value = poissonbinomial_pdf_fft(x_data)
268+
A = poissonbinomial_partialderivatives(x_data)
269+
poissonbinomial_pdf_fft_pullback(Δ) = (A * Δ,)
270+
return value, poissonbinomial_pdf_fft_pullback
273271
end
274272

275273
if isdefined(Distributions, :poissonbinomial_pdf)
276274
Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x)
277275
@grad function Distributions.poissonbinomial_pdf(x::TrackedArray)
278276
x_data = data(x)
279-
T = eltype(x_data)
280277
value = Distributions.poissonbinomial_pdf(x_data)
281-
function poissonbinomial_pdf_pullback(Δ)
282-
return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,)
283-
end
278+
A = poissonbinomial_partialderivatives(x_data)
279+
poissonbinomial_pdf_pullback(Δ) = (A * Δ,)
284280
return value, poissonbinomial_pdf_pullback
285281
end
286282
end

src/zygote.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,6 @@ ZygoteRules.@adjoint function Distributions.Uniform(args...)
1212
return ZygoteRules.pullback(TuringUniform, args...)
1313
end
1414

15-
## PoissonBinomial ##
16-
17-
# Zygote loads ForwardDiff, so this dummy adjoint should never be needed.
18-
# The adjoint that is used for `poissonbinomial_pdf_fft` is defined in `src/zygote_forwarddiff.jl`
19-
# ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real
20-
# error("This needs ForwardDiff. `using ForwardDiff` should fix this error.")
21-
# end
22-
2315
## Product
2416

2517
# Tests with `Kolmogorov` seem to fail otherwise?!

src/zygote_forwarddiff.jl

Lines changed: 0 additions & 20 deletions
This file was deleted.

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
23
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
34
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
45
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -16,7 +17,8 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
19-
ChainRulesTestUtils = "0.5.3, 0.6"
20+
ChainRulesCore = "0.9"
21+
ChainRulesTestUtils = "0.6.3"
2022
Combinatorics = "1.0.2"
2123
Distributions = "0.24.3"
2224
FiniteDifferences = "0.11.3, 0.12"

0 commit comments

Comments
 (0)