Skip to content

Commit 1886fa8

Browse files
committed
Merge branch 'master' into torfjelde/step-warmup
2 parents 25afc66 + dfb33b5 commit 1886fa8

File tree

6 files changed

+206
-71
lines changed

6 files changed

+206
-71
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probablistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "4.5.0"
6+
version = "5.0.0"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
1010
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
1111
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
12+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1213
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1314
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1415
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
@@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2122
[compat]
2223
BangBang = "0.3.19"
2324
ConsoleProgressMonitor = "0.1"
25+
FillArrays = "1"
2426
LogDensityProblems = "2"
2527
LoggingExtras = "0.4, 0.5, 1"
2628
ProgressLogging = "0.1"

docs/src/api.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,16 @@ Common keyword arguments for regular and parallel sampling are:
7979
- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that
8080
if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples.
8181
- `thinning` (default: `1`): factor by which to thin samples.
82+
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
83+
is passed `initial_state` as the `state` argument.
8284

8385
!!! info
8486
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).
8587

8688
There is no "official" way for providing initial parameter values yet.
87-
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
88-
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
89-
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
89+
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
90+
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
91+
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.
9092

9193
Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.
9294

src/AbstractMCMC.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging
88
using StatsBase: StatsBase
99
using TerminalLoggers: TerminalLoggers
1010
using Transducers: Transducers
11+
using FillArrays: FillArrays
1112

1213
using Distributed: Distributed
1314
using Logging: Logging

src/sample.jl

Lines changed: 82 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ function mcmcsample(
113113
discard_initial::Int=num_warmup,
114114
thinning=1,
115115
chain_type::Type=Any,
116+
initial_state=nothing,
116117
kwargs...,
117118
)
118119
# Check the number of requested samples.
@@ -145,9 +146,17 @@ function mcmcsample(
145146

146147
# Obtain the initial sample and state.
147148
sample, state = if num_warmup > 0
148-
step_warmup(rng, model, sampler; kwargs...)
149+
if initial_state === nothing
150+
step_warmup(rng, model, sampler; kwargs...)
151+
else
152+
step_warmup(rng, model, sampler, initial_state; kwargs...)
153+
end
149154
else
150-
step(rng, model, sampler; kwargs...)
155+
if initial_state === nothing
156+
step(rng, model, sampler; kwargs...)
157+
else
158+
step(rng, model, sampler, initial_state; kwargs...)
159+
end
151160
end
152161

153162
# Update the progress bar.
@@ -253,6 +262,7 @@ function mcmcsample(
253262
num_warmup=0,
254263
discard_initial=num_warmup,
255264
thinning=1,
265+
initial_state=nothing,
256266
kwargs...,
257267
)
258268
# Determine how many samples to drop from `num_warmup` and the
@@ -267,9 +277,17 @@ function mcmcsample(
267277
@ifwithprogresslogger progress name = progressname begin
268278
# Obtain the initial sample and state.
269279
sample, state = if num_warmup > 0
270-
step_warmup(rng, model, sampler; kwargs...)
280+
if initial_state === nothing
281+
step_warmup(rng, model, sampler; kwargs...)
282+
else
283+
step_warmup(rng, model, sampler, initial_state; kwargs...)
284+
end
271285
else
272-
step(rng, model, sampler; kwargs...)
286+
if initial_state === nothing
287+
step(rng, model, sampler; kwargs...)
288+
else
289+
step(rng, model, sampler, initial_state; kwargs...)
290+
end
273291
end
274292

275293
# Discard initial samples.
@@ -349,7 +367,8 @@ function mcmcsample(
349367
nchains::Integer;
350368
progress=PROGRESS[],
351369
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
352-
init_params=nothing,
370+
initial_params=nothing,
371+
initial_state=nothing,
353372
kwargs...,
354373
)
355374
# Check if actually multiple threads are used.
@@ -373,8 +392,9 @@ function mcmcsample(
373392
# Create a seed for each chain using the provided random number generator.
374393
seeds = rand(rng, UInt, nchains)
375394

376-
# Ensure that initial parameters are `nothing` or of the correct length
377-
check_initial_params(init_params, nchains)
395+
# Ensure that initial parameters and states are `nothing` or of the correct length
396+
check_initial_params(initial_params, nchains)
397+
check_initial_state(initial_state, nchains)
378398

379399
# Set up a chains vector.
380400
chains = Vector{Any}(undef, nchains)
@@ -425,10 +445,15 @@ function mcmcsample(
425445
_sampler,
426446
N;
427447
progress=false,
428-
init_params=if init_params === nothing
448+
initial_params=if initial_params === nothing
429449
nothing
430450
else
431-
init_params[chainidx]
451+
initial_params[chainidx]
452+
end,
453+
initial_state=if initial_state === nothing
454+
nothing
455+
else
456+
initial_state[chainidx]
432457
end,
433458
kwargs...,
434459
)
@@ -458,7 +483,8 @@ function mcmcsample(
458483
nchains::Integer;
459484
progress=PROGRESS[],
460485
progressname="Sampling ($(Distributed.nworkers()) processes)",
461-
init_params=nothing,
486+
initial_params=nothing,
487+
initial_state=nothing,
462488
kwargs...,
463489
)
464490
# Check if actually multiple processes are used.
@@ -471,8 +497,14 @@ function mcmcsample(
471497
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
472498
end
473499

474-
# Ensure that initial parameters are `nothing` or of the correct length
475-
check_initial_params(init_params, nchains)
500+
# Ensure that initial parameters and states are `nothing` or of the correct length
501+
check_initial_params(initial_params, nchains)
502+
check_initial_state(initial_state, nchains)
503+
504+
_initial_params =
505+
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
506+
_initial_state =
507+
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
476508

477509
# Create a seed for each chain using the provided random number generator.
478510
seeds = rand(rng, UInt, nchains)
@@ -509,7 +541,7 @@ function mcmcsample(
509541

510542
Distributed.@async begin
511543
try
512-
function sample_chain(seed, init_params=nothing)
544+
function sample_chain(seed, initial_params, initial_state)
513545
# Seed a new random number generator with the pre-made seed.
514546
Random.seed!(rng, seed)
515547

@@ -520,7 +552,8 @@ function mcmcsample(
520552
sampler,
521553
N;
522554
progress=false,
523-
init_params=init_params,
555+
initial_params=initial_params,
556+
initial_state=initial_state,
524557
kwargs...,
525558
)
526559

@@ -530,11 +563,9 @@ function mcmcsample(
530563
# Return the new chain.
531564
return chain
532565
end
533-
chains = if init_params === nothing
534-
Distributed.pmap(sample_chain, pool, seeds)
535-
else
536-
Distributed.pmap(sample_chain, pool, seeds, init_params)
537-
end
566+
chains = Distributed.pmap(
567+
sample_chain, pool, seeds, _initial_params, _initial_state
568+
)
538569
finally
539570
# Stop updating the progress bar.
540571
progress && put!(channel, false)
@@ -555,22 +586,29 @@ function mcmcsample(
555586
N::Integer,
556587
nchains::Integer;
557588
progressname="Sampling",
558-
init_params=nothing,
589+
initial_params=nothing,
590+
initial_state=nothing,
559591
kwargs...,
560592
)
561593
# Check if the number of chains is larger than the number of samples
562594
if nchains > N
563595
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
564596
end
565597

566-
# Ensure that initial parameters are `nothing` or of the correct length
567-
check_initial_params(init_params, nchains)
598+
# Ensure that initial parameters and states are `nothing` or of the correct length
599+
check_initial_params(initial_params, nchains)
600+
check_initial_state(initial_state, nchains)
601+
602+
_initial_params =
603+
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
604+
_initial_state =
605+
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
568606

569607
# Create a seed for each chain using the provided random number generator.
570608
seeds = rand(rng, UInt, nchains)
571609

572610
# Sample the chains.
573-
function sample_chain(i, seed, init_params=nothing)
611+
function sample_chain(i, seed, initial_params, initial_state)
574612
# Seed a new random number generator with the pre-made seed.
575613
Random.seed!(rng, seed)
576614

@@ -581,16 +619,13 @@ function mcmcsample(
581619
sampler,
582620
N;
583621
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
584-
init_params=init_params,
622+
initial_params=initial_params,
623+
initial_state=initial_state,
585624
kwargs...,
586625
)
587626
end
588627

589-
chains = if init_params === nothing
590-
map(sample_chain, 1:nchains, seeds)
591-
else
592-
map(sample_chain, 1:nchains, seeds, init_params)
593-
end
628+
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)
594629

595630
# Concatenate the chains together.
596631
return chainsstack(tighten_eltype(chains))
@@ -604,7 +639,6 @@ tighten_eltype(x::Vector{Any}) = map(identity, x)
604639
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
605640
),
606641
)
607-
608642
check_initial_params(::Nothing, n) = nothing
609643
function check_initial_params(x::AbstractArray, n)
610644
if length(x) != n
@@ -617,3 +651,21 @@ function check_initial_params(x::AbstractArray, n)
617651

618652
return nothing
619653
end
654+
655+
@nospecialize check_initial_state(x, n) = throw(
656+
ArgumentError(
657+
"initial states must be specified as a vector of length equal to the number of chains or `nothing`",
658+
),
659+
)
660+
check_initial_state(::Nothing, n) = nothing
661+
function check_initial_state(x::AbstractArray, n)
662+
if length(x) != n
663+
throw(
664+
ArgumentError(
665+
"incorrect number of initial states (expected $n, received $(length(x))"
666+
),
667+
)
668+
end
669+
670+
return nothing
671+
end

0 commit comments

Comments
 (0)