diff --git a/Project.toml b/Project.toml index 34b601a6..3580425a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.7.2" +version = "5.8.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index 69c43bc2..e53f2ccf 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -71,8 +71,9 @@ Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details. - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, + `callback(rng, model, sampler, sample, iteration; kwargs...)` is called after every sampling step, where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration + - Keyword arguments `kwargs...` are passed down from the call to `sample(...)`. If you are performing multiple-chain sampling, then `kwargs` _additionally_ contains a `chain_number` keyword argument, which runs from 1 to the number of chains. This is not present when performing single-chain sampling. - `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step, i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to [`AbstractMCMC.step`](@ref). diff --git a/src/sample.jl b/src/sample.jl index c1de64f1..913332d1 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -549,6 +549,7 @@ function mcmcsample( else initial_state[chainidx] end, + chain_number=chainidx, kwargs..., ) end @@ -669,7 +670,7 @@ function mcmcsample( Distributed.@async begin try function sample_chain( - seed, initial_params, initial_state, child_progress + seed, initial_params, initial_state, child_progress, chainidx ) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -683,6 +684,7 @@ function mcmcsample( progress=child_progress, initial_params=initial_params, initial_state=initial_state, + chain_number=chainidx, kwargs..., ) @@ -696,6 +698,7 @@ function mcmcsample( _initial_params, _initial_state, child_progresses, + 1:nchains; ) finally if progress == :overall @@ -755,6 +758,7 @@ function mcmcsample( progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), initial_params=initial_params, initial_state=initial_state, + chain_number=i, kwargs..., ) end diff --git a/test/sample.jl b/test/sample.jl index 5af2aa74..f561f535 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -679,6 +679,22 @@ @test all(chain[i].b == ref_chain[i].b for i in 1:N) end + @testset "chain_number keyword argument" begin + @testset for m in [MCMCSerial(), MCMCThreads(), MCMCDistributed()] + niters = 10 + channel = Channel{Int}() do chn + # check that the `chain_number` keyword argument is passed to the callback + function callback(args...; kwargs...) + @test haskey(kwargs, :chain_number) + return put!(chn, kwargs[:chain_number]) + end + chain = sample(MyModel(), MySampler(), m, niters, 4; callback=callback) + end + chain_numbers = collect(channel) + @test sort(chain_numbers) == repeat(1:4; inner=niters) + end + end + @testset "Sample vector of `NamedTuple`s" begin chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple}) # Check output type