
Commit 6547969: Update Turing test folder (#173)

This PR updates the Turing test folder of DynamicPPL (IMO they should be synced before finishing #150). And I still think that we should get rid of this test dependency as much as possible 😄

1 parent: a78db51

20 files changed, +468 -650 lines


test/Project.toml

Lines changed: 7 additions & 7 deletions
@@ -2,6 +2,7 @@
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
+AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -12,12 +13,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
-PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
-ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
+NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -28,13 +27,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

 [compat]
 AbstractMCMC = "1.0.1"
 AdvancedHMC = "0.2.25"
 AdvancedMH = "0.5.1"
+AdvancedVI = "0.1"
 Bijectors = "0.8.2"
 Distributions = "0.23.8"
 DistributionsAD = "0.6.3"
@@ -45,13 +45,13 @@ Libtask = "0.4.1"
 LogDensityProblems = "0.10.3"
 MCMCChains = "4.0.4"
 MacroTools = "0.5.5"
-PDMats = "0.10"
-ProgressLogging = "0.1.3"
+NamedArrays = "0.9"
 Reexport = "0.2"
 Requires = "1.0.1"
 SpecialFunctions = "0.10.3"
 StatsBase = "0.33"
 StatsFuns = "0.9.5"
 Tracker = "0.2.11"
 Zygote = "0.5.4"
+ZygoteRules = "0.2"
 julia = "1.3"

test/Turing/Turing.jl

Lines changed: 12 additions & 1 deletion
@@ -15,12 +15,14 @@ using Libtask
 @reexport using Distributions, MCMCChains, Libtask, AbstractMCMC, Bijectors
 using Tracker: Tracker

+import AdvancedVI
 import DynamicPPL: getspace, NoDist, NamedDist

 const PROGRESS = Ref(true)
 function turnprogress(switch::Bool)
     @info "[Turing]: progress logging is $(switch ? "enabled" : "disabled") globally"
     PROGRESS[] = switch
+    AdvancedVI.turnprogress(switch)
 end

 # Random probability measures.
@@ -64,6 +66,9 @@ end
 ###########
 # Exports #
 ###########
+# `using` statements for stuff to re-export
+using DynamicPPL: elementwise_loglikelihoods, generated_quantities, logprior, logjoint
+using StatsBase: predict

 # Turing essentials - modelling macros and inference algorithms
 export @model, # modelling
@@ -114,5 +119,11 @@ export @model, # modelling
     LogPoisson,
     NamedDist,
     filldist,
-    arraydist
+    arraydist,
+
+    predict,
+    elementwise_loglikelihoods,
+    generated_quantities,
+    logprior,
+    logjoint
 end
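
Worth noting: `turnprogress` now forwards its switch to AdvancedVI, so a single flag governs progress logging for both sampling and variational inference. A minimal usage sketch (not part of the diff):

    using Turing

    # Sets PROGRESS[] = false and, via the call added above, also invokes
    # AdvancedVI.turnprogress(false), so VI runs respect the same global setting.
    Turing.turnprogress(false)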

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ end
     kwargs...
 )
     if progress
-        @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
+        @warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
     if resume_from === nothing
         return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
@@ -149,7 +149,7 @@ function AbstractMCMC.sample(
     kwargs...
 )
     if progress
-        @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
+        @warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
     end
     return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
                                chain_type=chain_type, progress=false, kwargs...)

test/Turing/core/Core.jl

Lines changed: 6 additions & 4 deletions
@@ -10,17 +10,18 @@ using DynamicPPL: Model, AbstractSampler, Sampler, SampleFromPrior
 using LinearAlgebra: copytri!
 using Bijectors: PDMatDistribution
 import Bijectors: link, invlink
+using AdvancedVI
 using StatsFuns: logsumexp, softmax
 @reexport using DynamicPPL
 using Requires

+import ZygoteRules
+
 include("container.jl")
 include("ad.jl")
+include("deprecations.jl")
+
 function __init__()
-    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
-        include("compat/zygote.jl")
-        export ZygoteAD
-    end
     @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
         include("compat/reversediff.jl")
         export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache
@@ -50,6 +51,7 @@ export @model,
     setadsafe,
     ForwardDiffAD,
     TrackerAD,
+    ZygoteAD,
    value,
    gradient_logp,
    CHUNKSIZE,

test/Turing/core/ad.jl

Lines changed: 47 additions & 15 deletions
@@ -3,22 +3,22 @@
 ##############################
 const ADBACKEND = Ref(:forwarddiff)
 setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
-function setadbackend(::Val{:forward_diff})
-    Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
-    setadbackend(Val(:forwarddiff))
+function setadbackend(backend::Val)
+    _setadbackend(backend)
+    AdvancedVI.setadbackend(backend)
+    Bijectors.setadbackend(backend)
 end
-function setadbackend(::Val{:forwarddiff})
+
+function _setadbackend(::Val{:forwarddiff})
     CHUNKSIZE[] == 0 && setchunksize(40)
     ADBACKEND[] = :forwarddiff
 end
-
-function setadbackend(::Val{:reverse_diff})
-    Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend)
-    setadbackend(Val(:tracker))
-end
-function setadbackend(::Val{:tracker})
+function _setadbackend(::Val{:tracker})
     ADBACKEND[] = :tracker
 end
+function _setadbackend(::Val{:zygote})
+    ADBACKEND[] = :zygote
+end

 const ADSAFE = Ref(false)
 function setadsafe(switch::Bool)
@@ -42,12 +42,14 @@ getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg)
 getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[]

 struct TrackerAD <: ADBackend end
+struct ZygoteAD <: ADBackend end

 ADBackend() = ADBackend(ADBACKEND[])
 ADBackend(T::Symbol) = ADBackend(Val(T))

 ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
 ADBackend(::Val{:tracker}) = TrackerAD
+ADBackend(::Val{:zygote}) = ZygoteAD
 ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")

 """
@@ -56,13 +58,15 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t
 Find the autodifferentiation backend of the algorithm `alg`.
 """
 getADbackend(spl::Sampler) = getADbackend(spl.alg)
+getADbackend(spl::SampleFromPrior) = ADBackend()()

 """
 gradient_logp(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
     model::Model,
-    sampler::AbstractSampler=SampleFromPrior(),
+    sampler::AbstractSampler,
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )

 Computes the value of the log joint of `θ` and its gradient for the model
@@ -73,9 +77,10 @@ function gradient_logp(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
     model::Model,
-    sampler::Sampler
+    sampler::AbstractSampler,
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    return gradient_logp(getADbackend(sampler), θ, vi, model, sampler)
+    return gradient_logp(getADbackend(sampler), θ, vi, model, sampler, ctx)
 end

 """
@@ -85,6 +90,7 @@ gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )

 Compute the value of the log joint of `θ` and its gradient for the model
@@ -96,12 +102,13 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler=SampleFromPrior(),
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     # Define function to compute log joint.
     logp_old = getlogp(vi)
     function f(θ)
         new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler)
+        model(new_vi, sampler, ctx)
         logp = getlogp(new_vi)
         setlogp!(vi, ForwardDiff.value(logp))
         return logp
@@ -123,13 +130,14 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))

     # Specify objective function.
     function f(θ)
         new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler)
+        model(new_vi, sampler, ctx)
         return getlogp(new_vi)
     end

@@ -141,6 +149,30 @@ function gradient_logp(
     return l, ∂l∂θ
 end

+function gradient_logp(
+    backend::ZygoteAD,
+    θ::AbstractVector{<:Real},
+    vi::VarInfo,
+    model::Model,
+    sampler::AbstractSampler = SampleFromPrior(),
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
+)
+    T = typeof(getlogp(vi))
+
+    # Specify objective function.
+    function f(θ)
+        new_vi = VarInfo(vi, sampler, θ)
+        model(new_vi, sampler, context)
+        return getlogp(new_vi)
+    end
+
+    # Compute forward and reverse passes.
+    l::T, ȳ = ZygoteRules.pullback(f, θ)
+    ∂l∂θ::typeof(θ) = ȳ(1)[1]
+
+    return l, ∂l∂θ
+end
+
 function verifygrad(grad::AbstractVector{<:Real})
     if any(isnan, grad) || any(isinf, grad)
         @warn("Numerical error in gradients. Rejecting current proposal...")

test/Turing/core/compat/reversediff.jl

Lines changed: 13 additions & 15 deletions
@@ -10,7 +10,7 @@ function emptyrdcache end

 getrdcache() = RDCache[]
 ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
-function setadbackend(::Val{:reversediff})
+function _setadbackend(::Val{:reversediff})
     ADBACKEND[] = :reversediff
 end

@@ -20,13 +20,14 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))
-
+
     # Specify objective function.
     function f(θ)
         new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler)
+        model(new_vi, sampler, context)
         return getlogp(new_vi)
     end
     tp, result = taperesult(f, θ)
@@ -45,25 +46,24 @@ end
 @require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
     setrdcache(::Val{true}) = RDCache[] = true
     function emptyrdcache()
-        for k in keys(Memoization.caches)
-            if k[1] === typeof(memoized_taperesult)
-                pop!(Memoization.caches, k)
-            end
-        end
+        Memoization.empty_cache!(memoized_taperesult)
+        return
     end
+
     function gradient_logp(
         backend::ReverseDiffAD{true},
         θ::AbstractVector{<:Real},
         vi::VarInfo,
         model::Model,
         sampler::AbstractSampler = SampleFromPrior(),
+        context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
     )
         T = typeof(getlogp(vi))
-
+
         # Specify objective function.
         function f(θ)
             new_vi = VarInfo(vi, sampler, θ)
-            model(new_vi, sampler)
+            model(new_vi, sampler, context)
             return getlogp(new_vi)
         end
         ctp, result = memoized_taperesult(f, θ)
@@ -79,15 +79,13 @@ end
         f::F
         x::Tx
     end
-    function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
+    function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any})
         key = keys[1][1]
-        return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
+        return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid()))
     end
     memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
-    Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
+    Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
         return compiledtape(k.f, k.x), GradientResult(k.x)
     end
-    memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
-    Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
     compiledtape(f, x) = compile(GradientTape(f, x))
 end
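
As a rough sketch of how the memoized-tape path is meant to be used (again, not part of the diff; ReverseDiff and Memoization are optional packages whose support is activated through Requires):

    using Memoization, ReverseDiff
    using Turing

    Turing.setadbackend(:reversediff)   # routed through the new _setadbackend
    Turing.Core.setrdcache(Val(true))   # compile gradient tapes and memoize them; the cache key
                                        # now includes the function itself, input size, and thread id
    # ... run gradient-based inference here ...
    Turing.Core.emptyrdcache()          # drops all memoized tapes via Memoization.empty_cache!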

test/Turing/core/compat/zygote.jl

Lines changed: 0 additions & 28 deletions
This file was deleted.
