Skip to content

Commit b006f9e

Browse files
committed
Update Turing tests from 9ccd28
1 parent 2f74d66 commit b006f9e

File tree

17 files changed

+266
-626
lines changed

17 files changed

+266
-626
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
2121
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
2222
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
2323
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
24+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
25+
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
2426
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2527
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2628
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -42,4 +44,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4244
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
4345

4446
[targets]
45-
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
47+
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]

test/Turing/Turing.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module Turing
99
########################################################################
1010

1111
using Requires, Reexport, ForwardDiff
12-
using Bijectors, StatsFuns, SpecialFunctions
12+
using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
1313
using Statistics, LinearAlgebra
1414
using Markdown, Libtask, MacroTools
1515
@reexport using Distributions, MCMCChains, Libtask
@@ -111,7 +111,9 @@ export @model, # modelling
111111
VecBinomialLogit,
112112
OrderedLogistic,
113113
LogPoisson,
114-
NamedDist
114+
NamedDist,
115+
filldist,
116+
arraydist
115117

116118
# Reexports
117119
using AbstractMCMC: sample, psample

test/Turing/contrib/inference/AdvancedSMCExtensions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ function step(model, spl::Sampler{<:IPMCMC}, VarInfos::Array{VarInfo}, is_first:
255255

256256
# Run SMC & CSMC nodes
257257
for j in 1:spl.alg.n_nodes
258-
VarInfos[j].num_produce = 0
258+
reset_num_produce!(VarInfos[j])
259259
VarInfos[j] = step(model, spl.info[:samplers][j], VarInfos[j])[1]
260260
log_zs[j] = spl.info[:samplers][j].info[:logevidence][end]
261261
end

test/Turing/core/Core.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module Core
22

3+
using DistributionsAD, Bijectors
34
using MacroTools, Libtask, ForwardDiff, Random
45
using Distributions, LinearAlgebra
56
using ..Utilities, Reexport
@@ -10,12 +11,16 @@ using DynamicPPL: Model, runmodel!,
1011
using LinearAlgebra: copytri!
1112
using Bijectors: PDMatDistribution
1213
import Bijectors: link, invlink
13-
using DistributionsAD
1414
using StatsFuns: logsumexp, softmax
1515
@reexport using DynamicPPL
16+
using Requires
1617

1718
include("container.jl")
1819
include("ad.jl")
20+
@init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
21+
include("compat/zygote.jl")
22+
export ZygoteAD
23+
end
1924

2025
export @model,
2126
@varname,
@@ -47,8 +52,6 @@ export @model,
4752
ADBACKEND,
4853
setchunksize,
4954
verifygrad,
50-
gradient_logp_forward,
51-
gradient_logp_reverse,
5255
@varinfo,
5356
@logpdf,
5457
@sampler,

test/Turing/core/ad.jl

Lines changed: 32 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
##############################
22
# Global variables/constants #
33
##############################
4-
using Bijectors
5-
64
const ADBACKEND = Ref(:forward_diff)
7-
function setadbackend(backend_sym)
8-
@assert backend_sym == :forward_diff || backend_sym == :reverse_diff
9-
backend_sym == :forward_diff && CHUNKSIZE[] == 0 && setchunksize(40)
10-
ADBACKEND[] = backend_sym
11-
12-
Bijectors.setadbackend(backend_sym)
5+
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
6+
function setadbackend(::Val{:forward_diff})
7+
CHUNKSIZE[] == 0 && setchunksize(40)
8+
ADBACKEND[] = :forward_diff
9+
end
10+
function setadbackend(::Val{:reverse_diff})
11+
ADBACKEND[] = :reverse_diff
1312
end
1413

1514
const ADSAFE = Ref(false)
@@ -39,22 +38,23 @@ ADBackend() = ADBackend(ADBACKEND[])
3938
ADBackend(T::Symbol) = ADBackend(Val(T))
4039

4140
ADBackend(::Val{:forward_diff}) = ForwardDiffAD{CHUNKSIZE[]}
42-
ADBackend(::Val) = TrackerAD
41+
ADBackend(::Val{:reverse_diff}) = TrackerAD
42+
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
4343

4444
"""
45-
getADtype(alg)
45+
getADbackend(alg)
4646
47-
Finds the autodifferentiation type of the algorithm `alg`.
47+
Find the autodifferentiation backend of the algorithm `alg`.
4848
"""
49-
getADtype(spl::Sampler) = getADtype(spl.alg)
49+
getADbackend(spl::Sampler) = getADbackend(spl.alg)
5050

5151
"""
52-
gradient_logp(
53-
θ::AbstractVector{<:Real},
54-
vi::VarInfo,
55-
model::Model,
56-
sampler::AbstractSampler=SampleFromPrior(),
57-
)
52+
gradient_logp(
53+
θ::AbstractVector{<:Real},
54+
vi::VarInfo,
55+
model::Model,
56+
sampler::AbstractSampler=SampleFromPrior(),
57+
)
5858
5959
Computes the value of the log joint of `θ` and its gradient for the model
6060
specified by `(vi, sampler, model)` using whichever automatic differentation
@@ -66,26 +66,23 @@ function gradient_logp(
6666
model::Model,
6767
sampler::Sampler
6868
)
69-
ad_type = getADtype(sampler)
70-
if ad_type <: ForwardDiffAD
71-
return gradient_logp_forward(θ, vi, model, sampler)
72-
else ad_type <: TrackerAD
73-
return gradient_logp_reverse(θ, vi, model, sampler)
74-
end
69+
return gradient_logp(getADbackend(sampler), θ, vi, model, sampler)
7570
end
7671

7772
"""
78-
gradient_logp_forward(
73+
gradient_logp(
74+
backend::ADBackend,
7975
θ::AbstractVector{<:Real},
8076
vi::VarInfo,
8177
model::Model,
82-
spl::AbstractSampler=SampleFromPrior(),
78+
sampler::AbstractSampler = SampleFromPrior(),
8379
)
8480
85-
Computes the value of the log joint of `θ` and its gradient for the model
86-
specified by `(vi, spl, model)` using forwards-mode AD from ForwardDiff.jl.
81+
Compute the value of the log joint of `θ` and its gradient for the model
82+
specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`.
8783
"""
88-
function gradient_logp_forward(
84+
function gradient_logp(
85+
::ForwardDiffAD,
8986
θ::AbstractVector{<:Real},
9087
vi::VarInfo,
9188
model::Model,
@@ -110,23 +107,12 @@ function gradient_logp_forward(
110107

111108
return l, ∂l∂θ
112109
end
113-
114-
"""
115-
gradient_logp_reverse(
116-
θ::AbstractVector{<:Real},
117-
vi::VarInfo,
118-
model::Model,
119-
sampler::AbstractSampler=SampleFromPrior(),
120-
)
121-
122-
Computes the value of the log joint of `θ` and its gradient for the model
123-
specified by `(vi, sampler, model)` using reverse-mode AD from Tracker.jl.
124-
"""
125-
function gradient_logp_reverse(
110+
function gradient_logp(
111+
::TrackerAD,
126112
θ::AbstractVector{<:Real},
127113
vi::VarInfo,
128114
model::Model,
129-
sampler::AbstractSampler=SampleFromPrior(),
115+
sampler::AbstractSampler = SampleFromPrior(),
130116
)
131117
T = typeof(getlogp(vi))
132118

@@ -138,8 +124,9 @@ function gradient_logp_reverse(
138124

139125
# Compute forward and reverse passes.
140126
l_tracked, ȳ = Tracker.forward(f, θ)
141-
l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data((1)[1])
142127
# Remove tracking info from variables in model (because mutable state).
128+
l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data((1)[1])
129+
143130
return l, ∂l∂θ
144131
end
145132

@@ -153,31 +140,7 @@ function verifygrad(grad::AbstractVector{<:Real})
153140
end
154141
end
155142

156-
for F in (:link, :invlink)
157-
@eval begin
158-
function $F(
159-
dist::Dirichlet,
160-
x::Tracker.TrackedArray,
161-
::Type{Val{proj}} = Val{true}
162-
) where {proj}
163-
return Tracker.track($F, dist, x, Val{proj})
164-
end
165-
Tracker.@grad function $F(
166-
dist::Dirichlet,
167-
x::Tracker.TrackedArray,
168-
::Type{Val{proj}}
169-
) where {proj}
170-
x_data = Tracker.data(x)
171-
T = eltype(x_data)
172-
y = $F(dist, x_data, Val{proj})
173-
return y, Δ -> begin
174-
out = (ForwardDiff.jacobian(x -> $F(dist, x, Val{proj}), x_data)::Matrix{T})' * Δ
175-
return (nothing, out, nothing)
176-
end
177-
end
178-
end
179-
end
180-
143+
# These still seem necessary
181144
for F in (:link, :invlink)
182145
@eval begin
183146
$F(dist::PDMatDistribution, x::Tracker.TrackedArray) = Tracker.track($F, dist, x)

test/Turing/core/compat/zygote.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
struct ZygoteAD <: ADBackend end
2+
ADBackend(::Val{:zygote}) = ZygoteAD
3+
function setadbackend(::Val{:zygote})
4+
ADBACKEND[] = :zygote
5+
end
6+
7+
function gradient_logp(
8+
backend::ZygoteAD,
9+
θ::AbstractVector{<:Real},
10+
vi::VarInfo,
11+
model::Model,
12+
sampler::AbstractSampler = SampleFromPrior(),
13+
)
14+
T = typeof(getlogp(vi))
15+
16+
# Specify objective function.
17+
function f(θ)
18+
new_vi = VarInfo(vi, sampler, θ)
19+
return getlogp(runmodel!(model, new_vi, sampler))
20+
end
21+
22+
# Compute forward and reverse passes.
23+
l::T, ȳ = Zygote.pullback(f, θ)
24+
∂l∂θ::typeof(θ) = (1)[1]
25+
26+
return l, ∂l∂θ
27+
end
28+
29+
Zygote.@nograd DynamicPPL.updategid!

test/Turing/inference/AdvancedSMC.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ function AbstractMCMC.sample_end!(
260260
spl.state.average_logevidence = loge
261261
end
262262

263-
function assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, ::VarInfo)
263+
function DynamicPPL.assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName, ::VarInfo)
264264
vi = current_trace().vi
265265
if vn in getspace(spl)
266266
if ~haskey(vi, vn)
@@ -283,12 +283,13 @@ function assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName,
283283
r = rand(dist)
284284
push!(vi, vn, r, dist, Selector(:invalid))
285285
end
286-
acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
286+
lp = logpdf_with_trans(dist, r, istrans(vi, vn))
287+
acclogp!(vi, lp)
287288
end
288289
return r, 0
289290
end
290291

291-
function observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
292+
function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
292293
produce(logpdf(dist, value))
293294
return 0
294295
end

0 commit comments

Comments
 (0)