Skip to content

Commit df7d52e

Browse files
authored
Merge pull request #16 from TuringLang/csp/infer
Add `sample` method to sample until convergence
2 parents dc5a76f + 7ec741f commit df7d52e

File tree

3 files changed

+119
-4
lines changed

3 files changed

+119
-4
lines changed

src/AbstractMCMC.jl

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ may be provided as keyword argument `callback`. It is called after every samplin
102102
function StatsBase.sample(
103103
model::AbstractModel,
104104
sampler::AbstractSampler,
105-
N::Integer;
105+
arg;
106106
kwargs...
107107
)
108-
return sample(GLOBAL_RNG, model, sampler, N; kwargs...)
108+
return sample(GLOBAL_RNG, model, sampler, arg; kwargs...)
109109
end
110110

111111
function StatsBase.sample(
@@ -247,9 +247,10 @@ end
247247

248248
"""
249249
transitions_init(transition, model, sampler, N[; kwargs...])
250+
transitions_init(transition, model, sampler[; kwargs...])
250251
251252
Generate a container for the `N` transitions of the MCMC `sampler` for the provided
252-
`model`, whose first transition is `transition`.
253+
`model`, whose first transition is `transition`. Can be called with and without a predefined size `N`.
253254
"""
254255
function transitions_init(
255256
transition,
@@ -261,11 +262,21 @@ function transitions_init(
261262
return Vector{typeof(transition)}(undef, N)
262263
end
263264

265+
function transitions_init(
266+
transition,
267+
::AbstractModel,
268+
::AbstractSampler;
269+
kwargs...
270+
)
271+
return [transition]
272+
end
273+
264274
"""
265275
transitions_save!(transitions, iteration, transition, model, sampler, N[; kwargs...])
276+
transitions_save!(transitions, iteration, transition, model, sampler[; kwargs...])
266277
267278
Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of
268-
`transitions`.
279+
`transitions`. Can be called with and without a predefined size `N`.
269280
"""
270281
function transitions_save!(
271282
transitions::AbstractVector,
@@ -280,6 +291,19 @@ function transitions_save!(
280291
return
281292
end
282293

294+
295+
function transitions_save!(
296+
transitions::AbstractVector,
297+
iteration::Integer,
298+
transition,
299+
::AbstractModel,
300+
::AbstractSampler;
301+
kwargs...
302+
)
303+
push!(transitions, transition)
304+
return
305+
end
306+
283307
"""
284308
psample([rng::AbstractRNG, ]model::AbstractModel, sampler::AbstractSampler, N::Integer,
285309
nchains::Integer; kwargs...)
@@ -417,4 +441,71 @@ function steps!(
417441
return Stepper(rng, model, s, kwargs)
418442
end
419443

444+
##################################
445+
# Sample-until-convergence tools #
446+
##################################
447+
448+
"""
449+
sample([rng::AbstractRNG, ]model::AbstractModel, s::AbstractSampler, is_done::Function; kwargs...)
450+
451+
`sample` will continuously draw samples without defining a maximum number of samples until
452+
a convergence criteria defined by a user-defined function `is_done` returns `true`.
453+
454+
`is_done` is a function `f` that returns a `Bool`, with the signature
455+
456+
```julia
457+
f(rng::AbstractRNG, model::AbstractModel, s::AbstractSampler, transitions::Vector, iteration::Int; kwargs...)
458+
```
459+
460+
`is_done` should return `true` when sampling should end, and `false` otherwise.
461+
"""
462+
function StatsBase.sample(
463+
rng::AbstractRNG,
464+
model::AbstractModel,
465+
sampler::AbstractSampler,
466+
is_done;
467+
chain_type::Type=Any,
468+
progress = true,
469+
progressname = "Convergence sampling",
470+
callback = (args...; kwargs...) -> nothing,
471+
kwargs...
472+
)
473+
# Perform any necessary setup.
474+
sample_init!(rng, model, sampler, 1; kwargs...)
475+
476+
@ifwithprogresslogger progress name=progressname begin
477+
# Obtain the initial transition.
478+
transition = step!(rng, model, sampler, 1; iteration=1, kwargs...)
479+
480+
# Run callback.
481+
callback(rng, model, sampler, 1, 1, transition; kwargs...)
482+
483+
# Save the transition.
484+
transitions = transitions_init(transition, model, sampler; kwargs...)
485+
486+
# Step through the sampler until stopping.
487+
i = 2
488+
489+
while !is_done(rng, model, sampler, transitions, i; progress=progress, kwargs...)
490+
# Obtain the next transition.
491+
transition = step!(rng, model, sampler, 1, transition; iteration=i, kwargs...)
492+
493+
# Run callback.
494+
callback(rng, model, sampler, 1, i, transition; kwargs...)
495+
496+
# Save the transition.
497+
transitions_save!(transitions, i, transition, model, sampler; kwargs...)
498+
499+
# Increment iteration counter.
500+
i += 1
501+
end
502+
end
503+
504+
# Wrap up the sampler, if necessary.
505+
sample_end!(rng, model, sampler, i, transitions; kwargs...)
506+
507+
# Wrap the samples up.
508+
return bundle_samples(rng, model, sampler, i, transitions, chain_type; kwargs...)
509+
end
510+
420511
end # module AbstractMCMC

test/interface.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ struct MyTransition
66
end
77

88
struct MySampler <: AbstractMCMC.AbstractSampler end
9+
struct AnotherSampler <: AbstractMCMC.AbstractSampler end
910

1011
struct MyChain <: AbstractMCMC.AbstractChains
1112
as::Vector{Float64}
@@ -52,4 +53,20 @@ function AbstractMCMC.bundle_samples(
5253
return MyChain(as, bs)
5354
end
5455

56+
function is_done(
57+
rng::AbstractRNG,
58+
model::MyModel,
59+
s::MySampler,
60+
transitions,
61+
iteration::Int;
62+
chain_type::Type=Any,
63+
kwargs...
64+
)
65+
# Calculate the mean of x.b.
66+
bmean = mean(x.b for x in transitions)
67+
return abs(bmean) <= 0.001 || iteration >= 10_000
68+
end
69+
70+
# Set a default convergence function.
71+
AbstractMCMC.sample(model, sampler::MySampler; kwargs...) = sample(Random.GLOBAL_RNG, model, sampler, is_done; kwargs...)
5572
AbstractMCMC.chainscat(chains::Union{MyChain,Vector{<:MyChain}}...) = vcat(chains...)

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,4 +175,11 @@ include("interface.jl")
175175
@test Base.IteratorSize(iter) == Base.IsInfinite()
176176
@test Base.IteratorEltype(iter) == Base.EltypeUnknown()
177177
end
178+
179+
@testset "Sample without predetermined N" begin
180+
Random.seed!(1234)
181+
chain = sample(MyModel(), MySampler())
182+
bmean = mean(x.b for x in chain)
183+
@test abs(bmean) <= 0.001 && length(chain) < 10_000
184+
end
178185
end

0 commit comments

Comments
 (0)