Skip to content

Commit 25afc66

Browse files
committed
Merge branch 'master' into torfjelde/step-warmup
2 parents ef68d04 + d521815 commit 25afc66

File tree

8 files changed

+141
-58
lines changed

8 files changed

+141
-58
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "4.4.1"
6+
version = "4.5.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
@@ -24,15 +24,16 @@ ConsoleProgressMonitor = "0.1"
2424
LogDensityProblems = "2"
2525
LoggingExtras = "0.4, 0.5, 1"
2626
ProgressLogging = "0.1"
27-
StatsBase = "0.32, 0.33"
27+
StatsBase = "0.32, 0.33, 0.34"
2828
TerminalLoggers = "0.1"
2929
Transducers = "0.4.30"
3030
julia = "1.6"
3131

3232
[extras]
33+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
3334
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
3435
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3536
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3637

3738
[targets]
38-
test = ["IJulia", "Statistics", "Test"]
39+
test = ["FillArrays", "IJulia", "Statistics", "Test"]

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
44

55
[compat]
6-
Documenter = "0.27"
6+
Documenter = "1"
77
julia = "1"

docs/make.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ makedocs(;
99
format=Documenter.HTML(),
1010
modules=[AbstractMCMC],
1111
pages=["Home" => "index.md", "api.md", "design.md"],
12-
strict=true,
1312
checkdocs=:exports,
1413
)
1514

docs/src/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Common keyword arguments for regular and parallel sampling are:
8686
There is no "official" way for providing initial parameter values yet.
8787
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
8888
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
89-
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.
89+
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
9090

9191
Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
9292

src/interface.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,30 @@ be specified with the `chain_type` argument.
3030
By default, this method returns `samples`.
3131
"""
3232
function bundle_samples(
33-
samples, ::AbstractModel, ::AbstractSampler, ::Any, ::Type; kwargs...
33+
samples, model::AbstractModel, sampler::AbstractSampler, state, ::Type{T}; kwargs...
34+
) where {T}
35+
# dispatch to internal method for default implementations to fix
36+
# method ambiguity issues (see #120)
37+
return _bundle_samples(samples, model, sampler, state, T; kwargs...)
38+
end
39+
40+
function _bundle_samples(
41+
samples,
42+
@nospecialize(::AbstractModel),
43+
@nospecialize(::AbstractSampler),
44+
@nospecialize(::Any),
45+
::Type;
46+
kwargs...,
3447
)
3548
return samples
3649
end
37-
38-
function bundle_samples(
39-
samples::Vector, ::AbstractModel, ::AbstractSampler, ::Any, ::Type{Vector{T}}; kwargs...
50+
function _bundle_samples(
51+
samples::Vector,
52+
@nospecialize(::AbstractModel),
53+
@nospecialize(::AbstractSampler),
54+
@nospecialize(::Any),
55+
::Type{Vector{T}};
56+
kwargs...,
4057
) where {T}
4158
return map(samples) do sample
4259
convert(T, sample)

src/sample.jl

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,8 @@ function mcmcsample(
373373
# Create a seed for each chain using the provided random number generator.
374374
seeds = rand(rng, UInt, nchains)
375375

376-
# Ensure that initial parameters are `nothing` or indexable
377-
_init_params = _first_or_nothing(init_params, nchains)
376+
# Ensure that initial parameters are `nothing` or of the correct length
377+
check_initial_params(init_params, nchains)
378378

379379
# Set up a chains vector.
380380
chains = Vector{Any}(undef, nchains)
@@ -425,10 +425,10 @@ function mcmcsample(
425425
_sampler,
426426
N;
427427
progress=false,
428-
init_params=if _init_params === nothing
428+
init_params=if init_params === nothing
429429
nothing
430430
else
431-
_init_params[chainidx]
431+
init_params[chainidx]
432432
end,
433433
kwargs...,
434434
)
@@ -471,6 +471,9 @@ function mcmcsample(
471471
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
472472
end
473473

474+
# Ensure that initial parameters are `nothing` or of the correct length
475+
check_initial_params(init_params, nchains)
476+
474477
# Create a seed for each chain using the provided random number generator.
475478
seeds = rand(rng, UInt, nchains)
476479

@@ -560,6 +563,9 @@ function mcmcsample(
560563
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
561564
end
562565

566+
# Ensure that initial parameters are `nothing` or of the correct length
567+
check_initial_params(init_params, nchains)
568+
563569
# Create a seed for each chain using the provided random number generator.
564570
seeds = rand(rng, UInt, nchains)
565571

@@ -593,31 +599,21 @@ end
593599
tighten_eltype(x) = x
594600
tighten_eltype(x::Vector{Any}) = map(identity, x)
595601

596-
"""
597-
_first_or_nothing(x, n::Int)
598-
599-
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
600-
601-
If `x !== nothing`, then `x` has to contain at least `n` elements.
602-
"""
603-
function _first_or_nothing(x, n::Int)
604-
y = _first(x, n)
605-
length(y) == n || throw(
606-
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
607-
)
608-
return y
609-
end
610-
_first_or_nothing(::Nothing, ::Int) = nothing
602+
@nospecialize check_initial_params(x, n) = throw(
603+
ArgumentError(
604+
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
605+
),
606+
)
611607

612-
# `first(x, n::Int)` requires Julia 1.6
613-
function _first(x, n::Int)
614-
@static if VERSION >= v"1.6.0-DEV.431"
615-
first(x, n)
616-
else
617-
if x isa AbstractVector
618-
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
619-
else
620-
collect(Iterators.take(x, n))
621-
end
608+
check_initial_params(::Nothing, n) = nothing
609+
function check_initial_params(x::AbstractArray, n)
610+
if length(x) != n
611+
throw(
612+
ArgumentError(
613+
"incorrect number of initial parameters (expected $n, received $(length(x))"
614+
),
615+
)
622616
end
617+
618+
return nothing
623619
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using IJulia
44
using LogDensityProblems
55
using LoggingExtras: TeeLogger, EarlyFilteredLogger
66
using TerminalLoggers: TerminalLogger
7+
using FillArrays: FillArrays
78
using Transducers
89

910
using Distributed

test/sample.jl

Lines changed: 88 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,18 @@
162162
end
163163

164164
# initial parameters
165-
init_params = [(b=randn(), a=rand()) for _ in 1:100]
165+
nchains = 100
166+
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
166167
chains = sample(
167168
MyModel(),
168169
MySampler(),
169170
MCMCThreads(),
170171
3,
171-
100;
172+
nchains;
172173
progress=false,
173174
init_params=init_params,
174175
)
175-
@test length(chains) == 100
176+
@test length(chains) == nchains
176177
@test all(
177178
chain[1].a == params.a && chain[1].b == params.b for
178179
(chain, params) in zip(chains, init_params)
@@ -184,14 +185,36 @@
184185
MySampler(),
185186
MCMCThreads(),
186187
3,
187-
100;
188+
nchains;
188189
progress=false,
189-
init_params=Iterators.repeated(init_params),
190+
init_params=FillArrays.Fill(init_params, nchains),
190191
)
191-
@test length(chains) == 100
192+
@test length(chains) == nchains
192193
@test all(
193194
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
194195
)
196+
197+
# Too many `init_params`
198+
@test_throws ArgumentError sample(
199+
MyModel(),
200+
MySampler(),
201+
MCMCThreads(),
202+
3,
203+
nchains;
204+
progress=false,
205+
init_params=FillArrays.Fill(init_params, nchains + 1),
206+
)
207+
208+
# Too few `init_params`
209+
@test_throws ArgumentError sample(
210+
MyModel(),
211+
MySampler(),
212+
MCMCThreads(),
213+
3,
214+
nchains;
215+
progress=false,
216+
init_params=FillArrays.Fill(init_params, nchains - 1),
217+
)
195218
end
196219

197220
@testset "Multicore sampling" begin
@@ -274,17 +297,18 @@
274297
@test all(l.level > Logging.LogLevel(-1) for l in logs)
275298

276299
# initial parameters
277-
init_params = [(a=randn(), b=rand()) for _ in 1:100]
300+
nchains = 100
301+
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
278302
chains = sample(
279303
MyModel(),
280304
MySampler(),
281305
MCMCDistributed(),
282306
3,
283-
100;
307+
nchains;
284308
progress=false,
285309
init_params=init_params,
286310
)
287-
@test length(chains) == 100
311+
@test length(chains) == nchains
288312
@test all(
289313
chain[1].a == params.a && chain[1].b == params.b for
290314
(chain, params) in zip(chains, init_params)
@@ -296,15 +320,37 @@
296320
MySampler(),
297321
MCMCDistributed(),
298322
3,
299-
100;
323+
nchains;
300324
progress=false,
301-
init_params=Iterators.repeated(init_params),
325+
init_params=FillArrays.Fill(init_params, nchains),
302326
)
303-
@test length(chains) == 100
327+
@test length(chains) == nchains
304328
@test all(
305329
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
306330
)
307331

332+
# Too many `init_params`
333+
@test_throws ArgumentError sample(
334+
MyModel(),
335+
MySampler(),
336+
MCMCDistributed(),
337+
3,
338+
nchains;
339+
progress=false,
340+
init_params=FillArrays.Fill(init_params, nchains + 1),
341+
)
342+
343+
# Too few `init_params`
344+
@test_throws ArgumentError sample(
345+
MyModel(),
346+
MySampler(),
347+
MCMCDistributed(),
348+
3,
349+
nchains;
350+
progress=false,
351+
init_params=FillArrays.Fill(init_params, nchains - 1),
352+
)
353+
308354
# Remove workers
309355
rmprocs(pids...)
310356
end
@@ -360,17 +406,18 @@
360406
@test all(l.level > Logging.LogLevel(-1) for l in logs)
361407

362408
# initial parameters
363-
init_params = [(a=rand(), b=randn()) for _ in 1:100]
409+
nchains = 100
410+
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
364411
chains = sample(
365412
MyModel(),
366413
MySampler(),
367414
MCMCSerial(),
368415
3,
369-
100;
416+
nchains;
370417
progress=false,
371418
init_params=init_params,
372419
)
373-
@test length(chains) == 100
420+
@test length(chains) == nchains
374421
@test all(
375422
chain[1].a == params.a && chain[1].b == params.b for
376423
(chain, params) in zip(chains, init_params)
@@ -382,14 +429,36 @@
382429
MySampler(),
383430
MCMCSerial(),
384431
3,
385-
100;
432+
nchains;
386433
progress=false,
387-
init_params=Iterators.repeated(init_params),
434+
init_params=FillArrays.Fill(init_params, nchains),
388435
)
389-
@test length(chains) == 100
436+
@test length(chains) == nchains
390437
@test all(
391438
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
392439
)
440+
441+
# Too many `init_params`
442+
@test_throws ArgumentError sample(
443+
MyModel(),
444+
MySampler(),
445+
MCMCSerial(),
446+
3,
447+
nchains;
448+
progress=false,
449+
init_params=FillArrays.Fill(init_params, nchains + 1),
450+
)
451+
452+
# Too few `init_params`
453+
@test_throws ArgumentError sample(
454+
MyModel(),
455+
MySampler(),
456+
MCMCSerial(),
457+
3,
458+
nchains;
459+
progress=false,
460+
init_params=FillArrays.Fill(init_params, nchains - 1),
461+
)
393462
end
394463

395464
@testset "Ensemble sampling: Reproducibility" begin
@@ -564,7 +633,7 @@
564633
@test ismissing(chain[1].a)
565634
@test mean(x.a for x in view(chain, 2:1_000)) 0.5 atol = 6e-2
566635
@test var(x.a for x in view(chain, 2:1_000)) 1 / 12 atol = 1e-2
567-
@test mean(x.b for x in chain) 0 atol = 0.1
636+
@test mean(x.b for x in chain) 0 atol = 0.11
568637
@test var(x.b for x in chain) 1 atol = 0.15
569638
end
570639

0 commit comments

Comments
 (0)