Skip to content

Commit cb58871

Browse files
devmotioncpfiffer
andauthored
Update to AbstractMCMC 2 (#1428)
Co-authored-by: Cameron Pfiffer <[email protected]>
1 parent 5f2401b commit cb58871

File tree

23 files changed

+893
-1132
lines changed

23 files changed

+893
-1132
lines changed

Project.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.14.12"
3+
version = "0.15.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
77
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
88
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
99
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
10+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1011
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
1112
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1213
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -31,16 +32,17 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3132
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3233

3334
[compat]
34-
AbstractMCMC = "1"
35+
AbstractMCMC = "2.1"
3536
AdvancedHMC = "0.2.24"
36-
AdvancedMH = "0.5.1"
37+
AdvancedMH = "0.5.2"
3738
AdvancedVI = "0.1"
39+
BangBang = "0.3"
3840
Bijectors = "0.8"
3941
Distributions = "0.23.3"
4042
DistributionsAD = "0.6"
4143
DocStringExtensions = "0.8"
42-
DynamicPPL = "0.9.5"
43-
EllipticalSliceSampling = "0.2, 0.3"
44+
DynamicPPL = "0.10.0"
45+
EllipticalSliceSampling = "0.3"
4446
ForwardDiff = "0.10.3"
4547
Libtask = "0.4, 0.5"
4648
LogDensityProblems = "^0.9, 0.10"

src/Turing.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ end
7575
# Exports #
7676
###########
7777
# `using` statements for stuff to re-export
78-
using DynamicPPL: pointwise_loglikelihoods, elementwise_loglikelihoods,
79-
generated_quantities, logprior, logjoint
78+
using DynamicPPL: pointwise_loglikelihoods, generated_quantities, logprior, logjoint
8079
using StatsBase: predict
8180

8281
# Turing essentials - modelling macros and inference algorithms
@@ -137,8 +136,4 @@ export @model, # modelling
137136
generated_quantities,
138137
logprior,
139138
logjoint
140-
141-
# deprecations
142-
include("deprecations.jl")
143-
144139
end

src/contrib/inference/dynamichmc.jl

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -41,82 +41,82 @@ function DynamicNUTS{AD}(space::Symbol...) where AD
4141
DynamicNUTS{AD, space}()
4242
end
4343

44-
mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState
44+
struct DynamicNUTSState{V<:AbstractVarInfo,D}
4545
vi::V
4646
draws::Vector{D}
4747
end
4848

4949
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
5050

51-
function AbstractMCMC.sample_init!(
51+
DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
52+
53+
function DynamicPPL.initialstep(
5254
rng::AbstractRNG,
5355
model::Model,
5456
spl::Sampler{<:DynamicNUTS},
55-
N::Integer;
57+
vi::AbstractVarInfo;
58+
N::Int,
5659
kwargs...
5760
)
5861
# Set up lp function.
5962
function _lp(x)
60-
gradient_logp(x, spl.state.vi, model, spl)
63+
gradient_logp(x, vi, model, spl)
6164
end
6265

63-
# Set the parameters to a starting value.
64-
initialize_parameters!(spl; kwargs...)
65-
66-
model(spl.state.vi, SampleFromUniform())
67-
link!(spl.state.vi, spl)
68-
l, dl = _lp(spl.state.vi[spl])
66+
link!(vi, spl)
67+
l, dl = _lp(vi[spl])
6968
while !isfinite(l) || !isfinite(dl)
70-
model(spl.state.vi, SampleFromUniform())
71-
link!(spl.state.vi, spl)
72-
l, dl = _lp(spl.state.vi[spl])
69+
model(vi, SampleFromUniform())
70+
link!(vi, spl)
71+
l, dl = _lp(vi[spl])
7372
end
7473

75-
if spl.selector.tag == :default && !islinked(spl.state.vi, spl)
76-
link!(spl.state.vi, spl)
77-
model(spl.state.vi, spl)
74+
if spl.selector.tag == :default && !islinked(vi, spl)
75+
link!(vi, spl)
76+
model(vi, spl)
7877
end
7978

8079
results = mcmc_with_warmup(
8180
rng,
8281
FunctionLogDensity(
83-
length(spl.state.vi[spl]),
82+
length(vi[spl]),
8483
_lp
8584
),
8685
N
8786
)
87+
draws = results.chain
8888

89-
spl.state.draws = results.chain
89+
# Compute first transition and state.
90+
draw = popfirst!(draws)
91+
vi[spl] = draw
92+
transition = Transition(vi)
93+
state = DynamicNUTSState(vi, draws)
94+
95+
return transition, state
9096
end
9197

92-
function AbstractMCMC.step!(
98+
function AbstractMCMC.step(
9399
rng::AbstractRNG,
94100
model::Model,
95101
spl::Sampler{<:DynamicNUTS},
96-
N::Integer,
97-
transition;
102+
state::DynamicNUTSState;
98103
kwargs...
99104
)
105+
# Extract VarInfo object.
106+
vi = state.vi
107+
100108
# Pop the next draw off the vector.
101-
draw = popfirst!(spl.state.draws)
102-
spl.state.vi[spl] = draw
103-
return Transition(spl)
104-
end
109+
draw = popfirst!(state.draws)
110+
vi[spl] = draw
105111

106-
function Sampler(
107-
alg::DynamicNUTS,
108-
model::Model,
109-
s::Selector=Selector()
110-
)
111-
# Construct a state, using a default function.
112-
state = DynamicNUTSState(VarInfo(model), [])
112+
# Compute next transition.
113+
transition = Transition(vi)
113114

114-
# Return a new sampler.
115-
return Sampler(alg, Dict{Symbol,Any}(), s, state)
115+
return transition, state
116116
end
117117

118-
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
119-
function AbstractMCMC.sample(
118+
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
119+
function AbstractMCMC.sample(
120120
rng::AbstractRNG,
121121
model::AbstractModel,
122122
alg::DynamicNUTS,
@@ -131,9 +131,9 @@ end
131131
end
132132
if resume_from === nothing
133133
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
134-
chain_type=chain_type, progress=false, kwargs...)
134+
chain_type=chain_type, progress=false, N=N, kwargs...)
135135
else
136-
return resume(resume_from, N; chain_type=chain_type, progress=false, kwargs...)
136+
return resume(resume_from, N; chain_type=chain_type, progress=false, N=N, kwargs...)
137137
end
138138
end
139139

@@ -152,5 +152,5 @@ function AbstractMCMC.sample(
152152
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
153153
end
154154
return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
155-
chain_type=chain_type, progress=false, kwargs...)
155+
chain_type=chain_type, progress=false, N=N, kwargs...)
156156
end

src/core/Core.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import ZygoteRules
1919

2020
include("container.jl")
2121
include("ad.jl")
22-
include("deprecations.jl")
2322

2423
function __init__()
2524
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin

src/core/deprecations.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/deprecations.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)