Skip to content

Commit 405546f

Browse files
committed
Use ChainRulesCore and define adjoint for updategid! (#163)
This PR moves the definition of the adjoint for `updategid!` to DynamicPPL and replaces ZygoteRules with ChainRulesCore, as mentioned in a comment in TuringLang/Turing.jl#1401.
1 parent 7afd70c commit 405546f

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.9.0"
3+
version = "0.9.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
910
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1011
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
12-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1313

1414
[compat]
1515
AbstractMCMC = "1"
1616
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
17+
ChainRulesCore = "0.9.7"
1718
Distributions = "0.23.8"
1819
MacroTools = "0.5.1"
1920
NaturalSort = "1"
20-
ZygoteRules = "0.2"
2121
julia = "1.3"

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ using Distributions
55
using Bijectors
66

77
import AbstractMCMC
8+
import ChainRulesCore
89
import NaturalSort
910
import MacroTools
10-
import ZygoteRules
1111

1212
import Random
1313

src/compat/ad.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
# Prevent Zygote from differentiating push!
21
# See https://github.com/TuringLang/Turing.jl/issues/1199
3-
ZygoteRules.@adjoint function push!(
2+
ChainRulesCore.@non_differentiable push!(
43
vi::VarInfo,
54
vn::VarName,
65
r,
76
dist::Distribution,
87
gidset::Set{Selector}
98
)
10-
return push!(vi, vn, r, dist, gidset), _ -> nothing
11-
end
9+
10+
ChainRulesCore.@non_differentiable updategid!(
11+
vi::AbstractVarInfo,
12+
vn::VarName,
13+
spl::Sampler,
14+
)

test/Turing/core/compat/zygote.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,3 @@ function gradient_logp(
2626

2727
return l, ∂l∂θ
2828
end
29-
30-
Zygote.@nograd DynamicPPL.updategid!

0 commit comments

Comments
 (0)