Skip to content

Commit aaab9bb

Browse files
committed
Fix Zygote issue with dot_observe (#236)
This PR fixes TuringLang/Turing.jl#1595. It is an alternative to #235 that does not require us to rewrite the primal less efficiently which would affect regular execution and other AD backends. Co-authored-by: David Widmann <[email protected]>
1 parent 7c8edab commit aaab9bb

File tree

4 files changed

+43
-1
lines changed

4 files changed

+43
-1
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.16"
3+
version = "0.10.17"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,6 +11,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1212
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
14+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1415

1516
[compat]
1617
AbstractMCMC = "2, 3.0"
@@ -20,4 +21,5 @@ ChainRulesCore = "0.9.7"
2021
Distributions = "0.23.8, 0.24"
2122
MacroTools = "0.5.6"
2223
NaturalSort = "1"
24+
ZygoteRules = "0.2"
2325
julia = "1.3"

src/DynamicPPL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import AbstractMCMC
99
import ChainRulesCore
1010
import NaturalSort
1111
import MacroTools
12+
import ZygoteRules
1213

1314
import Random
1415

src/compat/ad.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,17 @@ ChainRulesCore.@non_differentiable updategid!(
1212
vn::VarName,
1313
spl::Sampler,
1414
)
15+
16+
# https://github.com/TuringLang/Turing.jl/issues/1595
17+
ZygoteRules.@adjoint function dot_observe(
18+
spl::Union{SampleFromPrior, SampleFromUniform},
19+
dists::AbstractArray{<:Distribution},
20+
value::AbstractArray,
21+
vi,
22+
)
23+
function dot_observe_fallback(spl, dists, value, vi)
24+
increment_num_produce!(vi)
25+
return sum(map(Distributions.loglikelihood, dists, value))
26+
end
27+
return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi)
28+
end

test/compat/ad.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,29 @@
2424

2525
test_model_ad(wishart_ad(), logp_wishart_ad)
2626
end
27+
28+
# https://github.com/TuringLang/Turing.jl/issues/1595
29+
@testset "dot_observe" begin
30+
function f_dot_observe(x)
31+
return DynamicPPL.dot_observe(SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo())
32+
end
33+
function f_dot_observe_manual(x)
34+
return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2])
35+
end
36+
37+
# Manual computation of the gradient.
38+
x = randn(2)
39+
val = f_dot_observe_manual(x)
40+
grad = ForwardDiff.gradient(f_dot_observe_manual, x)
41+
42+
@test ForwardDiff.gradient(f_dot_observe, x) grad
43+
44+
y, back = Tracker.forward(f_dot_observe, x)
45+
@test Tracker.data(y) val
46+
@test Tracker.data(back(1)[1]) grad
47+
48+
y, back = Zygote.pullback(f_dot_observe, x)
49+
@test y val
50+
@test back(1)[1] grad
51+
end
2752
end

0 commit comments

Comments
 (0)