Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ branches:
os:
- linux
- osx
julia:
- 1.0
matrix:
include:
- julia: 1.0
- julia: 1
env: JULIA_NUM_THREADS=1
- julia: 1
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ desc = "A lightweight interface for common MCMC methods."
version = "1.0.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"

[compat]
BangBang = "0.3.19"
ConsoleProgressMonitor = "0.1"
LoggingExtras = "0.4"
ProgressLogging = "0.1"
Expand Down
1 change: 1 addition & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module AbstractMCMC

import BangBang
import ConsoleProgressMonitor
import LoggingExtras
import ProgressLogging
Expand Down
49 changes: 30 additions & 19 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,15 @@ function step!(
end

"""
transitions_init(transition, model, sampler, N[; kwargs...])
transitions_init(transition, model, sampler[; kwargs...])
transitions(transition, model, sampler, N[; kwargs...])
transitions(transition, model, sampler[; kwargs...])

Generate a container for the `N` transitions of the MCMC `sampler` for the provided
`model`, whose first transition is `transition`. Can be called with and without a predefined size `N`.
`model`, whose first transition is `transition`.

The method can be called with and without a predefined size `N`.
"""
function transitions_init(
function transitions(
transition,
::AbstractModel,
::AbstractSampler,
Expand All @@ -104,43 +106,52 @@ function transitions_init(
return Vector{typeof(transition)}(undef, N)
end

function transitions_init(
function transitions(
transition,
::AbstractModel,
::AbstractSampler;
kwargs...
)
return [transition]
return Vector{typeof(transition)}(undef, 1)
end

"""
transitions_save!(transitions, iteration, transition, model, sampler, N[; kwargs...])
transitions_save!(transitions, iteration, transition, model, sampler[; kwargs...])
save!!(transitions, transition, iteration, model, sampler, N[; kwargs...])
save!!(transitions, transition, iteration, model, sampler[; kwargs...])

Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of
`transitions`. Can be called with and without a predefined size `N`.
`transitions`.

The function can be called with and without a predefined size `N`. By default, AbstractMCMC
uses ``setindex!`` and ``push!!`` from the Julia package
[BangBang](https://github.com/tkf/BangBang.jl) to write to and append to the container,
and widen the container type if needed.
"""
function transitions_save!(
transitions::AbstractVector,
iteration::Integer,
function save!!(
transitions,
transition,
iteration::Integer,
::AbstractModel,
::AbstractSampler,
::Integer;
kwargs...
)
transitions[iteration] = transition
return
return BangBang.setindex!!(transitions, transition, iteration)
end

function transitions_save!(
transitions::AbstractVector,
iteration::Integer,
function save!!(
transitions,
transition,
iteration::Integer,
::AbstractModel,
::AbstractSampler;
kwargs...
)
push!(transitions, transition)
return
return BangBang.push!!(transitions, transition)
end

Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) transitions(transition, model, sampler, N; kwargs...) false
Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) transitions(transition, model, sampler; kwargs...) false
Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) save!!(transitions, transition, iteration, model, sampler; kwargs...) false
Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) save!!(transitions, transition, iteration, model, sampler, N; kwargs...) false

11 changes: 6 additions & 5 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ function mcmcsample(
callback(rng, model, sampler, transition, 1)

# Save the transition.
transitions = transitions_init(transition, model, sampler, N; kwargs...)
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)
transitions = AbstractMCMC.transitions(transition, model, sampler, N; kwargs...)
transitions = save!!(transitions, transition, 1, model, sampler, N; kwargs...)

# Update the progress bar.
progress && ProgressLogging.@logprogress 1/N
Expand All @@ -96,7 +96,7 @@ function mcmcsample(
callback(rng, model, sampler, transition, i)

# Save the transition.
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
transitions = save!!(transitions, transition, i, model, sampler, N; kwargs...)

# Update the progress bar.
progress && ProgressLogging.@logprogress i/N
Expand Down Expand Up @@ -148,7 +148,8 @@ function mcmcsample(
callback(rng, model, sampler, transition, 1)

# Save the transition.
transitions = transitions_init(transition, model, sampler; kwargs...)
transitions = AbstractMCMC.transitions(transition, model, sampler; kwargs...)
transitions = save!!(transitions, transition, 1, model, sampler; kwargs...)

# Step through the sampler until stopping.
i = 2
Expand All @@ -161,7 +162,7 @@ function mcmcsample(
callback(rng, model, sampler, transition, i)

# Save the transition.
transitions_save!(transitions, i, transition, model, sampler; kwargs...)
transitions = save!!(transitions, transition, i, model, sampler; kwargs...)

# Increment iteration counter.
i += 1
Expand Down
27 changes: 11 additions & 16 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
struct MyModel <: AbstractMCMC.AbstractModel end

struct MyTransition
a::Float64
b::Float64
struct MyTransition{A,B}
a::A
b::B
end

struct MySampler <: AbstractMCMC.AbstractSampler end
struct AnotherSampler <: AbstractMCMC.AbstractSampler end

struct MyChain <: AbstractMCMC.AbstractChains
as::Vector{Float64}
bs::Vector{Float64}
struct MyChain{A,B} <: AbstractMCMC.AbstractChains
as::Vector{A}
bs::Vector{B}
end

function AbstractMCMC.step!(
Expand All @@ -23,7 +23,8 @@ function AbstractMCMC.step!(
loggers = false,
kwargs...
)
a = rand(rng)
# sample `a` is missing in the first step
a = transition === nothing ? missing : rand(rng)
b = randn(rng)

loggers && push!(LOGGERS, Logging.current_logger())
Expand All @@ -37,18 +38,12 @@ function AbstractMCMC.bundle_samples(
model::MyModel,
sampler::MySampler,
N::Integer,
transitions::Vector{MyTransition},
transitions::Vector{<:MyTransition},
chain_type::Type{MyChain};
kwargs...
)
n = length(transitions)
as = Vector{Float64}(undef, n)
bs = Vector{Float64}(undef, n)
for i in 1:n
transition = transitions[i]
as[i] = transition.a
bs[i] = transition.b
end
as = [t.a for t in transitions]
bs = [t.b for t in transitions]

return MyChain(as, bs)
end
Expand Down
59 changes: 35 additions & 24 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ include("interface.jl")
@test Logging.current_logger() === CURRENT_LOGGER

# test output type and size
@test chain isa Vector{MyTransition}
@test chain isa Vector{<:MyTransition}
@test length(chain) == N

# test some statistical properties
@test mean(x.a for x in chain) ≈ 0.5 atol=6e-2
@test var(x.a for x in chain) ≈ 1 / 12 atol=5e-3
@test mean(x.b for x in chain) ≈ 0.0 atol=5e-2
@test var(x.b for x in chain) ≈ 1 atol=6e-2
tail_chain = @view chain[2:end]
@test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2
@test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3
@test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2
@test var(x.b for x in tail_chain) ≈ 1 atol=6e-2
end

@testset "Juno" begin
Expand Down Expand Up @@ -118,26 +119,28 @@ include("interface.jl")
end

Random.seed!(1234)
chains = sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
N = 10_000
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
chain_type = MyChain)

# test output type and size
@test chains isa Vector{MyChain}
@test chains isa Vector{<:MyChain}
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == 10_000, chains)
@test all(x -> length(x.as) == length(x.bs) == N, chains)

# test some statistical properties
@test all(x -> isapprox(mean(x.as), 0.5; atol=1e-2), chains)
@test all(x -> isapprox(var(x.as), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(x.bs), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(x.bs), 1; atol=5e-2), chains)
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)

# test reproducibility
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
chain_type = MyChain)

@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
Expand Down Expand Up @@ -167,27 +170,30 @@ include("interface.jl")
include("interface.jl")
end

N = 10_000
Random.seed!(1234)
chains = sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 1000;
chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000;
chain_type = MyChain)

# Test output type and size.
@test chains isa Vector{MyChain}
@test chains isa Vector{<:MyChain}
@test all(c.as[1] === missing for c in chains)
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == 10_000, chains)
@test all(x -> length(x.as) == length(x.bs) == N, chains)

# Test some statistical properties.
@test all(x -> isapprox(mean(x.as), 0.5; atol=1e-2), chains)
@test all(x -> isapprox(var(x.as), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(x.bs), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(x.bs), 1; atol=5e-2), chains)
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)

# Test reproducibility.
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 1000;
chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000;
chain_type = MyChain)

@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
Expand All @@ -201,7 +207,7 @@ include("interface.jl")
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)

@test chain1 isa Vector{MyTransition}
@test chain1 isa Vector{<:MyTransition}
@test chain2 isa MyChain
end

Expand All @@ -217,10 +223,15 @@ include("interface.jl")
break
end

# don't save missing values
t.a === missing && continue

push!(as, t.a)
push!(bs, t.b)
end

@test length(as) == length(bs) == 998

@test mean(as) ≈ 0.5 atol=1e-2
@test var(as) ≈ 1 / 12 atol=5e-3
@test mean(bs) ≈ 0.0 atol=5e-2
Expand Down