Skip to content

Commit dbd6c58

Browse files
committed
Do not differentiate push!
1 parent d3ff242 commit dbd6c58

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
88
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
99
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
11+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1112

1213
[compat]
1314
AbstractMCMC = "0.4, 0.5, 1.0"
1415
Bijectors = "0.5.2, 0.6"
1516
Distributions = "0.22, 0.23"
1617
MacroTools = "0.5.1"
18+
ZygoteRules = "0.2"
1719
julia = "1"
1820

1921
[extras]

src/DynamicPPL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
44
using Distributions
55
using Bijectors
66
using MacroTools
7+
import ZygoteRules
78

89
import Base: string,
910
Symbol,
@@ -111,5 +112,6 @@ include("varinfo.jl")
111112
include("context_implementations.jl")
112113
include("compiler.jl")
113114
include("prob_macro.jl")
115+
include("ad.jl")
114116

115117
end # module

src/ad.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Prevent Zygote from differentiating push!
2+
# See https://github.com/TuringLang/Turing.jl/issues/1199
3+
ZygoteRules.@adjoint function push!(
4+
vi::VarInfo,
5+
vn::VarName,
6+
r,
7+
dist::Distribution,
8+
gidset::Set{Selector}
9+
)
10+
return push!(vi, vn, r, dist, gidset), _ -> nothing
11+
end

0 commit comments

Comments
 (0)