Skip to content

Commit bd41004

Browse files
authored
Merge pull request #35 from TuringLang/bangbang
Use BangBang to widen containers if needed
2 parents ad37c92 + 5c98582 commit bd41004

File tree

7 files changed

+86
-66
lines changed

7 files changed

+86
-66
lines changed

.travis.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@ branches:
66
os:
77
- linux
88
- osx
9-
julia:
10-
- 1.0
119
matrix:
1210
include:
11+
- julia: 1.0
1312
- julia: 1
1413
env: JULIA_NUM_THREADS=1
1514
- julia: 1

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ desc = "A lightweight interface for common MCMC methods."
66
version = "1.0.1"
77

88
[deps]
9+
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
910
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
1011
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
1112
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -16,6 +17,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1617
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
1718

1819
[compat]
20+
BangBang = "0.3.19"
1921
ConsoleProgressMonitor = "0.1"
2022
LoggingExtras = "0.4"
2123
ProgressLogging = "0.1"

src/AbstractMCMC.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module AbstractMCMC
22

3+
import BangBang
34
import ConsoleProgressMonitor
45
import LoggingExtras
56
import ProgressLogging

src/interface.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,15 @@ function step!(
8888
end
8989

9090
"""
91-
transitions_init(transition, model, sampler, N[; kwargs...])
92-
transitions_init(transition, model, sampler[; kwargs...])
91+
transitions(transition, model, sampler, N[; kwargs...])
92+
transitions(transition, model, sampler[; kwargs...])
9393
9494
Generate a container for the `N` transitions of the MCMC `sampler` for the provided
95-
`model`, whose first transition is `transition`. Can be called with and without a predefined size `N`.
95+
`model`, whose first transition is `transition`.
96+
97+
The method can be called with and without a predefined size `N`.
9698
"""
97-
function transitions_init(
99+
function transitions(
98100
transition,
99101
::AbstractModel,
100102
::AbstractSampler,
@@ -104,43 +106,52 @@ function transitions_init(
104106
return Vector{typeof(transition)}(undef, N)
105107
end
106108

107-
function transitions_init(
109+
function transitions(
108110
transition,
109111
::AbstractModel,
110112
::AbstractSampler;
111113
kwargs...
112114
)
113-
return [transition]
115+
return Vector{typeof(transition)}(undef, 1)
114116
end
115117

116118
"""
117-
transitions_save!(transitions, iteration, transition, model, sampler, N[; kwargs...])
118-
transitions_save!(transitions, iteration, transition, model, sampler[; kwargs...])
119+
save!!(transitions, transition, iteration, model, sampler, N[; kwargs...])
120+
save!!(transitions, transition, iteration, model, sampler[; kwargs...])
119121
120122
Save the `transition` of the MCMC `sampler` at the current `iteration` in the container of
121-
`transitions`. Can be called with and without a predefined size `N`.
123+
`transitions`.
124+
125+
The function can be called with and without a predefined size `N`. By default, AbstractMCMC
126+
uses ``setindex!`` and ``push!!`` from the Julia package
127+
[BangBang](https://github.com/tkf/BangBang.jl) to write to and append to the container,
128+
and widen the container type if needed.
122129
"""
123-
function transitions_save!(
124-
transitions::AbstractVector,
125-
iteration::Integer,
130+
function save!!(
131+
transitions,
126132
transition,
133+
iteration::Integer,
127134
::AbstractModel,
128135
::AbstractSampler,
129136
::Integer;
130137
kwargs...
131138
)
132-
transitions[iteration] = transition
133-
return
139+
return BangBang.setindex!!(transitions, transition, iteration)
134140
end
135141

136-
function transitions_save!(
137-
transitions::AbstractVector,
138-
iteration::Integer,
142+
function save!!(
143+
transitions,
139144
transition,
145+
iteration::Integer,
140146
::AbstractModel,
141147
::AbstractSampler;
142148
kwargs...
143149
)
144-
push!(transitions, transition)
145-
return
150+
return BangBang.push!!(transitions, transition)
146151
end
152+
153+
Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) transitions(transition, model, sampler, N; kwargs...) false
154+
Base.@deprecate transitions_init(transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) transitions(transition, model, sampler; kwargs...) false
155+
Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler; kwargs...) save!!(transitions, transition, iteration, model, sampler; kwargs...) false
156+
Base.@deprecate transitions_save!(transitions, iteration::Integer, transition, model::AbstractModel, sampler::AbstractSampler, N::Integer; kwargs...) save!!(transitions, transition, iteration, model, sampler, N; kwargs...) false
157+

src/sample.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ function mcmcsample(
8181
callback(rng, model, sampler, transition, 1)
8282

8383
# Save the transition.
84-
transitions = transitions_init(transition, model, sampler, N; kwargs...)
85-
transitions_save!(transitions, 1, transition, model, sampler, N; kwargs...)
84+
transitions = AbstractMCMC.transitions(transition, model, sampler, N; kwargs...)
85+
transitions = save!!(transitions, transition, 1, model, sampler, N; kwargs...)
8686

8787
# Update the progress bar.
8888
progress && ProgressLogging.@logprogress 1/N
@@ -96,7 +96,7 @@ function mcmcsample(
9696
callback(rng, model, sampler, transition, i)
9797

9898
# Save the transition.
99-
transitions_save!(transitions, i, transition, model, sampler, N; kwargs...)
99+
transitions = save!!(transitions, transition, i, model, sampler, N; kwargs...)
100100

101101
# Update the progress bar.
102102
progress && ProgressLogging.@logprogress i/N
@@ -148,7 +148,8 @@ function mcmcsample(
148148
callback(rng, model, sampler, transition, 1)
149149

150150
# Save the transition.
151-
transitions = transitions_init(transition, model, sampler; kwargs...)
151+
transitions = AbstractMCMC.transitions(transition, model, sampler; kwargs...)
152+
transitions = save!!(transitions, transition, 1, model, sampler; kwargs...)
152153

153154
# Step through the sampler until stopping.
154155
i = 2
@@ -161,7 +162,7 @@ function mcmcsample(
161162
callback(rng, model, sampler, transition, i)
162163

163164
# Save the transition.
164-
transitions_save!(transitions, i, transition, model, sampler; kwargs...)
165+
transitions = save!!(transitions, transition, i, model, sampler; kwargs...)
165166

166167
# Increment iteration counter.
167168
i += 1

test/interface.jl

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
struct MyModel <: AbstractMCMC.AbstractModel end
22

3-
struct MyTransition
4-
a::Float64
5-
b::Float64
3+
struct MyTransition{A,B}
4+
a::A
5+
b::B
66
end
77

88
struct MySampler <: AbstractMCMC.AbstractSampler end
99
struct AnotherSampler <: AbstractMCMC.AbstractSampler end
1010

11-
struct MyChain <: AbstractMCMC.AbstractChains
12-
as::Vector{Float64}
13-
bs::Vector{Float64}
11+
struct MyChain{A,B} <: AbstractMCMC.AbstractChains
12+
as::Vector{A}
13+
bs::Vector{B}
1414
end
1515

1616
function AbstractMCMC.step!(
@@ -23,7 +23,8 @@ function AbstractMCMC.step!(
2323
loggers = false,
2424
kwargs...
2525
)
26-
a = rand(rng)
26+
# sample `a` is missing in the first step
27+
a = transition === nothing ? missing : rand(rng)
2728
b = randn(rng)
2829

2930
loggers && push!(LOGGERS, Logging.current_logger())
@@ -37,18 +38,12 @@ function AbstractMCMC.bundle_samples(
3738
model::MyModel,
3839
sampler::MySampler,
3940
N::Integer,
40-
transitions::Vector{MyTransition},
41+
transitions::Vector{<:MyTransition},
4142
chain_type::Type{MyChain};
4243
kwargs...
4344
)
44-
n = length(transitions)
45-
as = Vector{Float64}(undef, n)
46-
bs = Vector{Float64}(undef, n)
47-
for i in 1:n
48-
transition = transitions[i]
49-
as[i] = transition.a
50-
bs[i] = transition.b
51-
end
45+
as = [t.a for t in transitions]
46+
bs = [t.b for t in transitions]
5247

5348
return MyChain(as, bs)
5449
end

test/runtests.jl

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,15 @@ include("interface.jl")
3535
@test Logging.current_logger() === CURRENT_LOGGER
3636

3737
# test output type and size
38-
@test chain isa Vector{MyTransition}
38+
@test chain isa Vector{<:MyTransition}
3939
@test length(chain) == N
4040

4141
# test some statistical properties
42-
@test mean(x.a for x in chain) 0.5 atol=6e-2
43-
@test var(x.a for x in chain) 1 / 12 atol=5e-3
44-
@test mean(x.b for x in chain) 0.0 atol=5e-2
45-
@test var(x.b for x in chain) 1 atol=6e-2
42+
tail_chain = @view chain[2:end]
43+
@test mean(x.a for x in tail_chain) 0.5 atol=6e-2
44+
@test var(x.a for x in tail_chain) 1 / 12 atol=5e-3
45+
@test mean(x.b for x in tail_chain) 0.0 atol=5e-2
46+
@test var(x.b for x in tail_chain) 1 atol=6e-2
4647
end
4748

4849
@testset "Juno" begin
@@ -118,26 +119,28 @@ include("interface.jl")
118119
end
119120

120121
Random.seed!(1234)
121-
chains = sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
122+
N = 10_000
123+
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
122124
chain_type = MyChain)
123125

124126
# test output type and size
125-
@test chains isa Vector{MyChain}
127+
@test chains isa Vector{<:MyChain}
126128
@test length(chains) == 1000
127-
@test all(x -> length(x.as) == length(x.bs) == 10_000, chains)
129+
@test all(x -> length(x.as) == length(x.bs) == N, chains)
128130

129131
# test some statistical properties
130-
@test all(x -> isapprox(mean(x.as), 0.5; atol=1e-2), chains)
131-
@test all(x -> isapprox(var(x.as), 1 / 12; atol=5e-3), chains)
132-
@test all(x -> isapprox(mean(x.bs), 0; atol=5e-2), chains)
133-
@test all(x -> isapprox(var(x.bs), 1; atol=5e-2), chains)
132+
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains)
133+
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
134+
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
135+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
134136

135137
# test reproducibility
136138
Random.seed!(1234)
137-
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
139+
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
138140
chain_type = MyChain)
139141

140-
@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
142+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
143+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
141144

142145
# Unexpected order of arguments.
143146
str = "Number of chains (10) is greater than number of samples per chain (5)"
@@ -173,27 +176,30 @@ include("interface.jl")
173176
include("interface.jl")
174177
end
175178

179+
N = 10_000
176180
Random.seed!(1234)
177-
chains = sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 1000;
181+
chains = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000;
178182
chain_type = MyChain)
179183

180184
# Test output type and size.
181-
@test chains isa Vector{MyChain}
185+
@test chains isa Vector{<:MyChain}
186+
@test all(c.as[1] === missing for c in chains)
182187
@test length(chains) == 1000
183-
@test all(x -> length(x.as) == length(x.bs) == 10_000, chains)
188+
@test all(x -> length(x.as) == length(x.bs) == N, chains)
184189

185190
# Test some statistical properties.
186-
@test all(x -> isapprox(mean(x.as), 0.5; atol=1e-2), chains)
187-
@test all(x -> isapprox(var(x.as), 1 / 12; atol=5e-3), chains)
188-
@test all(x -> isapprox(mean(x.bs), 0; atol=5e-2), chains)
189-
@test all(x -> isapprox(var(x.bs), 1; atol=5e-2), chains)
191+
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=1e-2), chains)
192+
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
193+
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
194+
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
190195

191196
# Test reproducibility.
192197
Random.seed!(1234)
193-
chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), 10_000, 1000;
198+
chains2 = sample(MyModel(), MySampler(), MCMCDistributed(), N, 1000;
194199
chain_type = MyChain)
195200

196-
@test all(((x, y),) -> x.as == y.as && x.bs == y.bs, zip(chains, chains2))
201+
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
202+
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
197203

198204
# Unexpected order of arguments.
199205
str = "Number of chains (10) is greater than number of samples per chain (5)"
@@ -213,7 +219,7 @@ include("interface.jl")
213219
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
214220
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)
215221

216-
@test chain1 isa Vector{MyTransition}
222+
@test chain1 isa Vector{<:MyTransition}
217223
@test chain2 isa MyChain
218224
end
219225

@@ -229,10 +235,15 @@ include("interface.jl")
229235
break
230236
end
231237

238+
# don't save missing values
239+
t.a === missing && continue
240+
232241
push!(as, t.a)
233242
push!(bs, t.b)
234243
end
235244

245+
@test length(as) == length(bs) == 998
246+
236247
@test mean(as) 0.5 atol=1e-2
237248
@test var(as) 1 / 12 atol=5e-3
238249
@test mean(bs) 0.0 atol=5e-2

0 commit comments

Comments
 (0)