Skip to content

Commit d8f5ab9

Browse files
committed
Add chain_number keyword argument when performing multi-chain sampling
1 parent 33fdad8 commit d8f5ab9

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
keywords = ["markov chain monte carlo", "probabilistic programming"]
44
license = "MIT"
55
desc = "A lightweight interface for common MCMC methods."
6-
version = "5.7.2"
6+
version = "5.7.3"
77

88
[deps]
99
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"

src/sample.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ function mcmcsample(
549549
else
550550
initial_state[chainidx]
551551
end,
552+
chain_number=chainidx,
552553
kwargs...,
553554
)
554555
end
@@ -669,7 +670,7 @@ function mcmcsample(
669670
Distributed.@async begin
670671
try
671672
function sample_chain(
672-
seed, initial_params, initial_state, child_progress
673+
seed, initial_params, initial_state, child_progress, chainidx
673674
)
674675
# Seed a new random number generator with the pre-made seed.
675676
Random.seed!(rng, seed)
@@ -683,6 +684,7 @@ function mcmcsample(
683684
progress=child_progress,
684685
initial_params=initial_params,
685686
initial_state=initial_state,
687+
chain_number=chainidx,
686688
kwargs...,
687689
)
688690

@@ -696,6 +698,7 @@ function mcmcsample(
696698
_initial_params,
697699
_initial_state,
698700
child_progresses,
701+
1:nchains;
699702
)
700703
finally
701704
if progress == :overall
@@ -755,6 +758,7 @@ function mcmcsample(
755758
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
756759
initial_params=initial_params,
757760
initial_state=initial_state,
761+
chain_number=i,
758762
kwargs...,
759763
)
760764
end

test/sample.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,19 @@
679679
@test all(chain[i].b == ref_chain[i].b for i in 1:N)
680680
end
681681

682+
@testset "chain_number keyword argument" begin
683+
@testset for m in [MCMCSerial(), MCMCThreads(), MCMCDistributed()]
684+
# check that the `chain_number` keyword argument is passed to the callback
685+
chain_numbers = Int[]
686+
function callback(args...; kwargs...)
687+
@test haskey(kwargs, :chain_number)
688+
return push!(chain_numbers, kwargs[:chain_number])
689+
end
690+
chain = sample(MyModel(), MySampler(), m, 10, 4; callback=callback)
691+
@test sort(chain_numbers) == repeat(1:4; inner=10)
692+
end
693+
end
694+
682695
@testset "Sample vector of `NamedTuple`s" begin
683696
chain = sample(MyModel(), MySampler(), 1_000; chain_type=Vector{NamedTuple})
684697
# Check output type

0 commit comments

Comments
 (0)