Skip to content

Commit 47d26f9

Browse files
authored
Merge pull request #34 from TuringLang/cpfiffer-patch-1
Updated AbstractMCMC compat
2 parents 2a5e875 + bcf8fe6 commit 47d26f9

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1010
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111

1212
[compat]
13-
AbstractMCMC = "~0.1"
13+
AbstractMCMC = "0.3"
1414
AdvancedHMC = "0.2.20"
1515
Bijectors = "0.5.2"
1616
Distributions = "0.22"

test/Turing/inference/Inference.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,22 @@ function AbstractMCMC.sample(
144144
model::AbstractModel,
145145
alg::InferenceAlgorithm,
146146
N::Integer;
147+
chain_type=Chains,
147148
kwargs...
148149
)
149-
return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], kwargs...)
150+
return sample(rng, model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
150151
end
151152

152153
function AbstractMCMC.sample(
153154
model::AbstractModel,
154155
alg::InferenceAlgorithm,
155156
N::Integer;
156157
resume_from=nothing,
158+
chain_type=Chains,
157159
kwargs...
158160
)
159161
if resume_from === nothing
160-
return sample(model, Sampler(alg, model), N; progress=PROGRESS[], kwargs...)
162+
return sample(model, Sampler(alg, model), N; progress=PROGRESS[], chain_type=chain_type, kwargs...)
161163
else
162164
return resume(resume_from, N)
163165
end
@@ -169,9 +171,10 @@ function AbstractMCMC.psample(
169171
alg::InferenceAlgorithm,
170172
N::Integer,
171173
n_chains::Integer;
174+
chain_type=Chains,
172175
kwargs...
173176
)
174-
return psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, kwargs...)
177+
return psample(GLOBAL_RNG, model, alg, N, n_chains; progress=false, chain_type=chain_type, kwargs...)
175178
end
176179

177180
function AbstractMCMC.psample(
@@ -180,9 +183,10 @@ function AbstractMCMC.psample(
180183
alg::InferenceAlgorithm,
181184
N::Integer,
182185
n_chains::Integer;
186+
chain_type=Chains,
183187
kwargs...
184188
)
185-
return psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, kwargs...)
189+
return psample(rng, model, Sampler(alg, model), N, n_chains; progress=false, chain_type=chain_type, kwargs...)
186190
end
187191

188192
function AbstractMCMC.sample_init!(
@@ -318,7 +322,8 @@ function AbstractMCMC.bundle_samples(
318322
model::AbstractModel,
319323
spl::Sampler,
320324
N::Integer,
321-
ts::Vector{<:AbstractTransition};
325+
ts::Vector{<:AbstractTransition},
326+
ct::Type{Chains};
322327
discard_adapt::Bool=true,
323328
save_state=true,
324329
kwargs...
@@ -375,7 +380,7 @@ function save(c::Chains, spl::AbstractSampler, model, vi, samples)
375380
return setinfo(c, merge(nt, c.info))
376381
end
377382

378-
function resume(c::Chains, n_iter::Int; kwargs...)
383+
function resume(c::Chains, n_iter::Int; chain_type=Chains, kwargs...)
379384
@assert !isempty(c.info) "[Turing] cannot resume from a chain without state info"
380385

381386
# Sample a new chain.
@@ -386,6 +391,7 @@ function resume(c::Chains, n_iter::Int; kwargs...)
386391
n_iter;
387392
resume_from=c,
388393
reuse_spl_n=n_iter,
394+
chain_type=chain_type,
389395
kwargs...
390396
)
391397

0 commit comments

Comments
 (0)