@@ -12,20 +12,18 @@ using ProgressMeter, LinearAlgebra
12
12
using .. Turing: PROGRESS, NamedDist, NoDist, Turing
13
13
using StatsFuns: logsumexp
14
14
using Random: GLOBAL_RNG, AbstractRNG, randexp
15
- using AbstractMCMC, DynamicPPL
15
+ using DynamicPPL
16
16
using Bijectors: _debug
17
17
18
18
import MCMCChains: Chains
19
19
import AdvancedHMC; const AHMC = AdvancedHMC
20
20
import .. Core: getchunksize, getADtype
21
- import AbstractMCMC: AbstractTransition, sample, step!, sample_init!,
22
- transitions_init, sample_end!, AbstractSampler, transition_type,
23
- callback, init_callback, AbstractCallback, psample
21
+ import AbstractMCMC
22
+ using AbstractMCMC: AbstractModel, AbstractCallback, AbstractSampler
24
23
import DynamicPPL: tilde, dot_tilde, getspace, get_matching_type
25
24
26
25
export InferenceAlgorithm,
27
26
Hamiltonian,
28
- AbstractGibbs,
29
27
GibbsComponent,
30
28
StaticHamiltonian,
31
29
AdaptiveHamiltonian,
@@ -44,20 +42,8 @@ export InferenceAlgorithm,
44
42
SMC,
45
43
CSMC,
46
44
PG,
47
- PIMH,
48
- PMMH,
49
- IPMCMC, # particle-based sampling
50
45
assume,
51
46
observe,
52
- step,
53
- WelfordVar,
54
- WelfordCovar,
55
- NaiveCovar,
56
- get_var,
57
- get_covar,
58
- add_sample!,
59
- reset!,
60
- step!,
61
47
resume
62
48
63
49
# ######################
95
81
# Default Transition #
96
82
# #####################
97
83
98
- struct Transition{T, F<: AbstractFloat } <: AbstractTransition
84
+ struct Transition{T, F<: AbstractFloat }
99
85
θ :: T
100
86
lp :: F
101
87
end
@@ -147,19 +133,19 @@ function AbstractMCMC.sample(
147
133
chain_type= Chains,
148
134
kwargs...
149
135
)
150
- return sample (rng, model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
136
+ return AbstractMCMC . sample (rng, model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
151
137
end
152
138
153
139
function AbstractMCMC. sample (
154
- model:: AbstractModel ,
140
+ model:: Model ,
155
141
alg:: InferenceAlgorithm ,
156
142
N:: Integer ;
157
143
resume_from= nothing ,
158
144
chain_type= Chains,
159
145
kwargs...
160
146
)
161
147
if resume_from === nothing
162
- return sample (model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
148
+ return AbstractMCMC . sample (model, Sampler (alg, model), N; progress= PROGRESS[], chain_type= chain_type, kwargs... )
163
149
else
164
150
return resume (resume_from, N)
165
151
end
@@ -174,7 +160,7 @@ function AbstractMCMC.psample(
174
160
chain_type= Chains,
175
161
kwargs...
176
162
)
177
- return psample (GLOBAL_RNG, model, alg, N, n_chains; progress= false , chain_type= chain_type, kwargs... )
163
+ return AbstractMCMC . psample (GLOBAL_RNG, model, alg, N, n_chains; progress= false , chain_type= chain_type, kwargs... )
178
164
end
179
165
180
166
function AbstractMCMC. psample (
@@ -186,7 +172,7 @@ function AbstractMCMC.psample(
186
172
chain_type= Chains,
187
173
kwargs...
188
174
)
189
- return psample (rng, model, Sampler (alg, model), N, n_chains; progress= false , chain_type= chain_type, kwargs... )
175
+ return AbstractMCMC . psample (rng, model, Sampler (alg, model), N, n_chains; progress= false , chain_type= chain_type, kwargs... )
190
176
end
191
177
192
178
function AbstractMCMC. sample_init! (
206
192
function AbstractMCMC. sample_end! (
207
193
:: AbstractRNG ,
208
194
:: Model ,
209
- :: AbstractSampler ,
195
+ :: Sampler ,
210
196
:: Integer ,
211
- :: Vector{<:AbstractTransition} ;
197
+ :: Vector ;
212
198
kwargs...
213
199
)
214
200
# Silence the default API function.
244
230
# Chain making utilities #
245
231
# #########################
246
232
247
- function _params_to_array (ts:: Vector{<:AbstractTransition} , spl:: Sampler )
233
+ function _params_to_array (ts:: Vector , spl:: Sampler )
248
234
names_set = Set {String} ()
249
235
# Extract the parameter names and values from each transition.
250
236
dicts = map (ts) do t
@@ -276,7 +262,7 @@ function flatten_namedtuple(nt::NamedTuple)
276
262
return [vn[1 ] for vn in names_vals], [vn[2 ] for vn in names_vals]
277
263
end
278
264
279
- function get_transition_extras (ts:: Vector{<:AbstractTransition} )
265
+ function get_transition_extras (ts:: Vector )
280
266
# Get the extra field names from the sampler state type.
281
267
# This handles things like :lp or :weight.
282
268
extra_params = additional_parameters (eltype (ts))
@@ -322,8 +308,8 @@ function AbstractMCMC.bundle_samples(
322
308
model:: AbstractModel ,
323
309
spl:: Sampler ,
324
310
N:: Integer ,
325
- ts:: Vector{<:AbstractTransition} ,
326
- ct :: Type{Chains} ;
311
+ ts:: Vector ,
312
+ :: Type{Chains} ;
327
313
discard_adapt:: Bool = true ,
328
314
save_state= true ,
329
315
kwargs...
@@ -384,7 +370,7 @@ function resume(c::Chains, n_iter::Int; chain_type=Chains, kwargs...)
384
370
@assert ! isempty (c. info) " [Turing] cannot resume from a chain without state info"
385
371
386
372
# Sample a new chain.
387
- newchain = sample (
373
+ newchain = AbstractMCMC . sample (
388
374
c. info[:range ],
389
375
c. info[:model ],
390
376
c. info[:spl ],
@@ -432,13 +418,12 @@ include("is.jl")
432
418
include (" AdvancedSMC.jl" )
433
419
include (" gibbs.jl" )
434
420
include (" ../contrib/inference/sghmc.jl" )
435
- include (" ../contrib/inference/AdvancedSMCExtensions.jl" )
436
421
437
422
# ###############
438
423
# Typing tools #
439
424
# ###############
440
425
441
- for alg in (:SMC , :PG , :PMMH , :IPMCMC , : MH , :IS , :ESS , :Gibbs )
426
+ for alg in (:SMC , :PG , :MH , :IS , :ESS , :Gibbs )
442
427
@eval getspace (:: $alg{space} ) where {space} = space
443
428
end
444
429
for alg in (:HMC , :HMCDA , :NUTS , :SGLD , :SGHMC )
494
479
# # Fallback functions
495
480
496
481
alg_str (spl:: Sampler ) = string (nameof (typeof (spl. alg)))
497
- transition_type (spl:: Sampler ) = typeof (Transition (spl))
498
482
499
483
# utility funcs for querying sampler information
500
484
require_gradient (spl:: Sampler ) = false
0 commit comments