Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.7.2"
version = "5.8.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@ Common keyword arguments for regular and parallel sampling are:
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging. See the section on [Progress logging](#progress-logging) below for more details.
- `chain_type` (default: `Any`): determines the type of the returned chain
- `callback` (default: `nothing`): if `callback !== nothing`, then
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
`callback(rng, model, sampler, sample, iteration; kwargs...)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
- Keyword arguments `kwargs...` are passed down from the call to `sample(...)`. If you are performing multiple-chain sampling, then `kwargs` _additionally_ contains a `chain_number` keyword argument, which runs from 1 to the number of chains. This is not present when performing single-chain sampling.
- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step,
i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to
[`AbstractMCMC.step`](@ref).
Expand Down
6 changes: 5 additions & 1 deletion src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ function mcmcsample(
else
initial_state[chainidx]
end,
chain_number=chainidx,
kwargs...,
)
end
Expand Down Expand Up @@ -669,7 +670,7 @@ function mcmcsample(
Distributed.@async begin
try
function sample_chain(
seed, initial_params, initial_state, child_progress
seed, initial_params, initial_state, child_progress, chainidx
)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)
Expand All @@ -683,6 +684,7 @@ function mcmcsample(
progress=child_progress,
initial_params=initial_params,
initial_state=initial_state,
chain_number=chainidx,
kwargs...,
)

Expand All @@ -696,6 +698,7 @@ function mcmcsample(
_initial_params,
_initial_state,
child_progresses,
1:nchains;
)
finally
if progress == :overall
Expand Down Expand Up @@ -755,6 +758,7 @@ function mcmcsample(
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
initial_params=initial_params,
initial_state=initial_state,
chain_number=i,
kwargs...,
)
end
Expand Down
16 changes: 16 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,22 @@
@test all(chain[i].b == ref_chain[i].b for i in 1:N)
end

@testset "chain_number keyword argument" begin
@testset for m in [MCMCSerial(), MCMCThreads(), MCMCDistributed()]
niters = 10
channel = Channel{Int}() do chn
# check that the `chain_number` keyword argument is passed to the callback
function callback(args...; kwargs...)
@test haskey(kwargs, :chain_number)
return put!(chn, kwargs[:chain_number])
end
chain = sample(MyModel(), MySampler(), m, niters, 4; callback=callback)
end
chain_numbers = collect(channel)
@test sort(chain_numbers) == repeat(1:4; inner=niters)
end
end
Comment on lines +683 to +696
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joelkandiah maybe not surprising but the original naive version of this test just pushed to a vector inside the callback, and it would randomly fail with MCMCThreads (sometimes it would only pick up 39 entries rather than 40, so I guess there was some race condition...) this seems to work around it fine -- thought you might be interested


@testset "Sample vector of `NamedTuple`s" begin
chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple})
# Check output type
Expand Down
Loading