Skip to content

Commit 420cf8c

Browse files
authored
Update ChainRulesCore (#172)
1 parent f73a43f commit 420cf8c

File tree

5 files changed

+31
-28
lines changed

5 files changed

+31
-28
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2323

2424
[compat]
2525
Adapt = "2, 3"
26-
ChainRules = "0.7"
27-
ChainRulesCore = "0.9.21"
26+
ChainRules = "0.7, 0.8"
27+
ChainRulesCore = "0.9.44, 0.10"
2828
Compat = "3.6"
2929
DiffRules = "0.1, 1.0"
3030
Distributions = "0.23.3, 0.24, 0.25"

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77
function turing_chol_back(A::AbstractMatrix, check)
88
C, chol_pullback = rrule(cholesky, A, Val(false), check=check)
99
function back(Δ)
10-
= Composite{typeof(C)}((U=Δ[1]))
10+
= Tangent{typeof(C)}((U=Δ[1]))
1111
∂C = chol_pullback(Ȳ)[2]
1212
(∂C, nothing)
1313
end
@@ -21,7 +21,7 @@ end
2121
function symm_turing_chol_back(A::AbstractMatrix, check, uplo)
2222
C, chol_pullback = rrule(cholesky, Symmetric(A,uplo), Val(false), check=check)
2323
function back(Δ)
24-
= Composite{typeof(C)}((U=Δ[1]))
24+
= Tangent{typeof(C)}((U=Δ[1]))
2525
∂C = chol_pullback(Ȳ)[2]
2626
(∂C, nothing, nothing)
2727
end

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
1717
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1818

1919
[compat]
20-
ChainRulesCore = "0.9"
21-
ChainRulesTestUtils = "0.6.3"
20+
ChainRulesCore = "0.9.44, 0.10"
21+
ChainRulesTestUtils = "0.6.3, 0.7"
2222
Combinatorics = "1.0.2"
2323
Distributions = "0.24.3, 0.25"
2424
FiniteDifferences = "0.11.3, 0.12"

test/ad/chainrules.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@
3434
test_frule(StatsFuns.tdistlogpdf, x, y)
3535
test_rrule(StatsFuns.tdistlogpdf, x, y)
3636

37+
# TODO: Re-enable if https://github.com/JuliaMath/SpecialFunctions.jl/pull/325 is fixed
3738
# use `BigFloat` to avoid Rmath implementation in finite differencing check
3839
# (returns `NaN` for non-integer values)
39-
n = rand(1:100)
40-
x = BigFloat(n)
41-
y = big(logistic(randn()))
42-
z = BigFloat(rand(1:n))
43-
test_frule(StatsFuns.binomlogpdf, x, y, z)
44-
test_rrule(StatsFuns.binomlogpdf, x, y, z)
40+
#n = rand(1:100)
41+
#x = BigFloat(n)
42+
#y = big(logistic(randn()))
43+
#z = BigFloat(rand(1:n))
44+
#test_frule(StatsFuns.binomlogpdf, x, y, z)
45+
#test_rrule(StatsFuns.binomlogpdf, x, y, z)
4546

4647
x = big(exp(randn()))
4748
y = BigFloat(rand(1:100))

test/ad/distributions.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454
DistSpec(Poisson, (0.5,), 1),
5555
DistSpec(Poisson, (0.5,), [1, 1]),
5656

57-
DistSpec(Skellam, (1.0, 2.0), -2; broken=(:Zygote,)),
58-
DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)),
57+
DistSpec(Skellam, (1.0, 2.0), -2),
58+
DistSpec(Skellam, (1.0, 2.0), [-2, -2]),
5959

6060
DistSpec(PoissonBinomial, ([0.5, 0.5],), 0),
6161

@@ -162,8 +162,9 @@
162162

163163
DistSpec(NormalCanon, (1.0, 2.0), 0.5),
164164

165-
DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5; broken=(:Zygote,)),
165+
DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5),
166166

167+
DistSpec(Pareto, (), 1.5),
167168
DistSpec(Pareto, (1.0,), 1.5),
168169
DistSpec(Pareto, (1.0, 1.0), 1.5),
169170

@@ -213,10 +214,7 @@
213214
# Stackoverflow caused by SpecialFunctions.besselix
214215
DistSpec(VonMises, (1.0,), 1.0),
215216
DistSpec(VonMises, (1, 1), 1),
216-
217-
# Only some Zygote tests are broken and therefore this can not be checked
218-
DistSpec(Pareto, (), 1.5; broken=(:Zygote,)),
219-
217+
220218
# Some tests are broken on some Julia versions, therefore it can't be checked reliably
221219
DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]; broken=(:Zygote,)),
222220
]
@@ -395,17 +393,21 @@
395393
# Skellam only fails in these tests with ReverseDiff
396394
# Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126
397395
# PoissonBinomial fails with Zygote
396+
# Matrix case does not work with Skellam:
397+
# https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493
398398
filldist_broken = if d.f(d.θ...) isa Skellam
399-
(d.broken..., :ReverseDiff)
399+
((d.broken..., :ReverseDiff), (d.broken..., :Zygote, :ReverseDiff))
400400
elseif d.f(d.θ...) isa PoissonBinomial
401-
(d.broken..., :Zygote)
401+
((d.broken..., :Zygote), (d.broken..., :Zygote))
402402
else
403-
d.broken
403+
(d.broken, d.broken)
404404
end
405-
arraydist_broken = if d.f(d.θ...) isa PoissonBinomial
406-
(d.broken..., :Zygote)
405+
arraydist_broken = if d.f(d.θ...) isa Skellam
406+
(d.broken, (d.broken..., :Zygote))
407+
elseif d.f(d.θ...) isa PoissonBinomial
408+
((d.broken..., :Zygote), (d.broken..., :Zygote))
407409
else
408-
d.broken
410+
(d.broken, d.broken)
409411
end
410412

411413
# Create `filldist` distribution
@@ -416,7 +418,7 @@
416418
f_arraydist =...,) -> arraydist([d.f...) for _ in 1:n])
417419
d_arraydist = f_arraydist(d.θ...)
418420

419-
for sz in ((n,), (n, 2))
421+
for (i, sz) in enumerate(((n,), (n, 2)))
420422
# Matrix case doesn't work for continuous distributions for some reason
421423
# now but not too important (?!)
422424
if length(sz) == 2 && Distributions.value_support(typeof(d)) === Continuous
@@ -434,7 +436,7 @@
434436
d.θ,
435437
x,
436438
d.xtrans;
437-
broken=filldist_broken,
439+
broken=filldist_broken[i],
438440
)
439441
)
440442
test_ad(
@@ -444,7 +446,7 @@
444446
d.θ,
445447
x,
446448
d.xtrans;
447-
broken=arraydist_broken,
449+
broken=arraydist_broken[i],
448450
)
449451
)
450452
end

0 commit comments

Comments
 (0)