Skip to content

Commit 45582ae

Browse files
committed
Update Turing source + tests
1 parent 9460909 commit 45582ae

File tree

12 files changed

+78
-131
lines changed

12 files changed

+78
-131
lines changed

test/Turing/Turing.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@ using Requires, Reexport, ForwardDiff
1212
using Bijectors, StatsFuns, SpecialFunctions
1313
using Statistics, LinearAlgebra, ProgressMeter
1414
using Markdown, Libtask, MacroTools
15-
using AbstractMCMC
15+
using AbstractMCMC: sample, psample
1616
@reexport using Distributions, MCMCChains, Libtask
1717
using Tracker: Tracker
1818

1919
import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
20-
import MCMCChains: AbstractChains, Chains
2120
import DynamicPPL: getspace, runmodel!
2221

2322
const PROGRESS = Ref(true)
@@ -85,9 +84,6 @@ export @model, # modelling
8584
SMC,
8685
CSMC,
8786
PG,
88-
PIMH,
89-
PMMH,
90-
IPMCMC,
9187

9288
vi, # variational inference
9389
ADVI,

test/Turing/contrib/inference/dynamichmc.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using AbstractMCMC: init_callback, NoCallback
1+
using AbstractMCMC: NoCallback
22

33
###
44
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
@@ -50,7 +50,7 @@ end
5050

5151
getspace(::DynamicNUTS{<:Any, space}) where {space} = space
5252

53-
function sample_init!(
53+
function AbstractMCMC.sample_init!(
5454
rng::AbstractRNG,
5555
model::Model,
5656
spl::Sampler{<:DynamicNUTS},
@@ -84,11 +84,12 @@ function sample_init!(
8484
spl.state.draws = results.chain
8585
end
8686

87-
function step!(
87+
function AbstractMCMC.step!(
8888
rng::AbstractRNG,
8989
model::Model,
9090
spl::Sampler{<:DynamicNUTS},
91-
N::Integer;
91+
N::Integer,
92+
transition;
9293
kwargs...
9394
)
9495
# Pop the next draw off the vector.

test/Turing/inference/AdvancedSMC.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,21 @@
77
#######################
88

99
"""
10-
ParticleTransition{T, F<:AbstractFloat} <: AbstractTransition
10+
ParticleTransition{T, F<:AbstractFloat}
1111
1212
Fields:
1313
- `θ`: The parameters for any given sample.
1414
- `lp`: The log pdf for the sample's parameters.
1515
- `le`: The log evidence retrieved from the particle.
1616
- `weight`: The weight of the particle the sample was retrieved from.
1717
"""
18-
struct ParticleTransition{T, F<:AbstractFloat} <: AbstractTransition
18+
struct ParticleTransition{T, F<:AbstractFloat}
1919
θ::T
2020
lp::F
2121
le::F
2222
weight::F
2323
end
2424

25-
transition_type(spl::Sampler{<:ParticleInference}) = ParticleTransition
26-
2725
function additional_parameters(::Type{<:ParticleTransition})
2826
return [:lp,:le, :weight]
2927
end
@@ -76,9 +74,9 @@ function Sampler(alg::SMC, model::Model, s::Selector)
7674
return Sampler(alg, dict, s, state)
7775
end
7876

79-
function sample_init!(
77+
function AbstractMCMC.sample_init!(
8078
::AbstractRNG,
81-
model::Turing.Model,
79+
model::Model,
8280
spl::Sampler{<:SMC},
8381
N::Integer;
8482
kwargs...
@@ -105,12 +103,13 @@ function sample_init!(
105103
end
106104
end
107105

108-
function step!(
106+
function AbstractMCMC.step!(
109107
::AbstractRNG,
110-
model::Turing.Model,
108+
model::Model,
111109
spl::Sampler{<:SMC},
112-
::Integer;
113-
iteration=-1,
110+
::Integer,
111+
transition;
112+
iteration = -1,
114113
kwargs...
115114
)
116115
# check that we received a real iteration number
@@ -182,11 +181,12 @@ function Sampler(alg::PG, model::Model, s::Selector)
182181
return Sampler(alg, info, s, state)
183182
end
184183

185-
function step!(
184+
function AbstractMCMC.step!(
186185
::AbstractRNG,
187-
model::Turing.Model,
186+
model::Model,
188187
spl::Sampler{<:PG},
189-
::Integer;
188+
::Integer,
189+
transition;
190190
kwargs...
191191
)
192192
# obtain or create reference particle
@@ -232,12 +232,12 @@ function step!(
232232
return ParticleTransition(params, lp, pc.logE, 1.0)
233233
end
234234

235-
function sample_end!(
235+
function AbstractMCMC.sample_end!(
236236
::AbstractRNG,
237237
::Model,
238238
spl::Sampler{<:ParticleInference},
239239
N::Integer,
240-
ts::Vector{ParticleTransition};
240+
ts::Vector{<:ParticleTransition};
241241
kwargs...
242242
)
243243
# Set the default for resuming the sampler.

test/Turing/inference/Inference.jl

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,18 @@ using ProgressMeter, LinearAlgebra
1212
using ..Turing: PROGRESS, NamedDist, NoDist, Turing
1313
using StatsFuns: logsumexp
1414
using Random: GLOBAL_RNG, AbstractRNG, randexp
15-
using AbstractMCMC, DynamicPPL
15+
using DynamicPPL
1616
using Bijectors: _debug
1717

1818
import MCMCChains: Chains
1919
import AdvancedHMC; const AHMC = AdvancedHMC
2020
import ..Core: getchunksize, getADtype
21-
import AbstractMCMC: AbstractTransition, sample, step!, sample_init!,
22-
transitions_init, sample_end!, AbstractSampler, transition_type,
23-
callback, init_callback, AbstractCallback, psample
21+
import AbstractMCMC
22+
using AbstractMCMC: AbstractModel, AbstractCallback, AbstractSampler
2423
import DynamicPPL: tilde, dot_tilde, getspace, get_matching_type
2524

2625
export InferenceAlgorithm,
2726
Hamiltonian,
28-
AbstractGibbs,
2927
GibbsComponent,
3028
StaticHamiltonian,
3129
AdaptiveHamiltonian,
@@ -44,20 +42,8 @@ export InferenceAlgorithm,
4442
SMC,
4543
CSMC,
4644
PG,
47-
PIMH,
48-
PMMH,
49-
IPMCMC, # particle-based sampling
5045
assume,
5146
observe,
52-
step,
53-
WelfordVar,
54-
WelfordCovar,
55-
NaiveCovar,
56-
get_var,
57-
get_covar,
58-
add_sample!,
59-
reset!,
60-
step!,
6147
resume
6248

6349
#######################
@@ -95,7 +81,7 @@ end
9581
# Default Transition #
9682
######################
9783

98-
struct Transition{T, F<:AbstractFloat} <: AbstractTransition
84+
struct Transition{T, F<:AbstractFloat}
9985
θ :: T
10086
lp :: F
10187
end
@@ -147,19 +133,19 @@ function AbstractMCMC.sample(
147133
chain_type=Chains,
148134
kwargs...
149135
)
150-
return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
136+
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
151137
end
152138

153139
function AbstractMCMC.sample(
154-
model::AbstractModel,
140+
model::Model,
155141
alg::InferenceAlgorithm,
156142
N::Integer;
157143
resume_from=nothing,
158144
chain_type=Chains,
159145
kwargs...
160146
)
161147
if resume_from === nothing
162-
return sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
148+
return AbstractMCMC.sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
163149
else
164150
return resume(resume_from, N)
165151
end
@@ -174,7 +160,7 @@ function AbstractMCMC.psample(
174160
chain_type=Chains,
175161
kwargs...
176162
)
177-
return psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, chain_type=chain_type, kwargs...)
163+
return AbstractMCMC.psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, chain_type=chain_type, kwargs...)
178164
end
179165

180166
function AbstractMCMC.psample(
@@ -186,7 +172,7 @@ function AbstractMCMC.psample(
186172
chain_type=Chains,
187173
kwargs...
188174
)
189-
return psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, chain_type=chain_type, kwargs...)
175+
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, chain_type=chain_type, kwargs...)
190176
end
191177

192178
function AbstractMCMC.sample_init!(
@@ -206,9 +192,9 @@ end
206192
function AbstractMCMC.sample_end!(
207193
::AbstractRNG,
208194
::Model,
209-
::AbstractSampler,
195+
::Sampler,
210196
::Integer,
211-
::Vector{<:AbstractTransition};
197+
::Vector;
212198
kwargs...
213199
)
214200
# Silence the default API function.
@@ -244,7 +230,7 @@ end
244230
# Chain making utilities #
245231
##########################
246232

247-
function _params_to_array(ts::Vector{<:AbstractTransition}, spl::Sampler)
233+
function _params_to_array(ts::Vector, spl::Sampler)
248234
names_set = Set{String}()
249235
# Extract the parameter names and values from each transition.
250236
dicts = map(ts) do t
@@ -276,7 +262,7 @@ function flatten_namedtuple(nt::NamedTuple)
276262
return [vn[1] for vn in names_vals], [vn[2] for vn in names_vals]
277263
end
278264

279-
function get_transition_extras(ts::Vector{<:AbstractTransition})
265+
function get_transition_extras(ts::Vector)
280266
# Get the extra field names from the sampler state type.
281267
# This handles things like :lp or :weight.
282268
extra_params = additional_parameters(eltype(ts))
@@ -322,8 +308,8 @@ function AbstractMCMC.bundle_samples(
322308
model::AbstractModel,
323309
spl::Sampler,
324310
N::Integer,
325-
ts::Vector{<:AbstractTransition},
326-
ct::Type{Chains};
311+
ts::Vector,
312+
::Type{Chains};
327313
discard_adapt::Bool=true,
328314
save_state=true,
329315
kwargs...
@@ -384,7 +370,7 @@ function resume(c::Chains, n_iter::Int; chain_type=Chains, kwargs...)
384370
@assert !isempty(c.info) "[Turing] cannot resume from a chain without state info"
385371

386372
# Sample a new chain.
387-
newchain = sample(
373+
newchain = AbstractMCMC.sample(
388374
c.info[:range],
389375
c.info[:model],
390376
c.info[:spl],
@@ -432,13 +418,12 @@ include("is.jl")
432418
include("AdvancedSMC.jl")
433419
include("gibbs.jl")
434420
include("../contrib/inference/sghmc.jl")
435-
include("../contrib/inference/AdvancedSMCExtensions.jl")
436421

437422
################
438423
# Typing tools #
439424
################
440425

441-
for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :ESS, :Gibbs)
426+
for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs)
442427
@eval getspace(::$alg{space}) where {space} = space
443428
end
444429
for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@@ -494,7 +479,6 @@ end
494479
## Fallback functions
495480

496481
alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
497-
transition_type(spl::Sampler) = typeof(Transition(spl))
498482

499483
# utility funcs for querying sampler information
500484
require_gradient(spl::Sampler) = false

test/Turing/inference/ess.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,23 @@ isgaussian(::NormalCanon) = true
5454
isgaussian(::AbstractMvNormal) = true
5555

5656
# always accept in the first step
57-
function step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...)
57+
function AbstractMCMC.step!(
58+
::AbstractRNG,
59+
::Model,
60+
spl::Sampler{<:ESS},
61+
::Integer,
62+
::Nothing;
63+
kwargs...
64+
)
5865
return Transition(spl)
5966
end
6067

61-
function step!(
68+
function AbstractMCMC.step!(
6269
rng::AbstractRNG,
6370
model::Model,
6471
spl::Sampler{<:ESS},
6572
::Integer,
66-
::Transition;
73+
transition;
6774
kwargs...
6875
)
6976
# obtain mean of distribution

0 commit comments

Comments
 (0)