-
Notifications
You must be signed in to change notification settings - Fork 19
Add chain_number
keyword argument when performing multi-chain sampling
#174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
AbstractMCMC.jl documentation for PR #174 is available at: |
d250afa
to
982e473
Compare
@testset for m in [MCMCSerial(), MCMCThreads(), MCMCDistributed()] | ||
niters = 10 | ||
channel = Channel{Int}() do chn | ||
# check that the `chain_number` keyword argument is passed to the callback | ||
function callback(args...; kwargs...) | ||
@test haskey(kwargs, :chain_number) | ||
return put!(chn, kwargs[:chain_number]) | ||
end | ||
chain = sample(MyModel(), MySampler(), m, niters, 4; callback=callback) | ||
end | ||
chain_numbers = collect(channel) | ||
@test sort(chain_numbers) == repeat(1:4; inner=niters) | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@joelkandiah maybe not surprising but the original naive version of this test just pushed to a vector inside the callback, and it would randomly fail with MCMCThreads (sometimes it would only pick up 39 entries rather than 40, so I guess there was some race condition...) this seems to work around it fine -- thought you might be interested
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see no issue, make a ton of sense (semver tripped me a bit, battling with myself whether patch was enough, but agree minor release the right way to go).
Thanks! I know, it's quite tricky to decide sometimes 😅 |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #174 +/- ##
===========================
===========================
☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Callbacks in general have no knowledge of which chain is being sampled, meaning that there is no way to differentiate callbacks being called from different chains:
AbstractMCMC.jl/src/sample.jl
Lines 235 to 237 in 33fdad8
This PR makes the multiple-chain
sample
method pass a keyword argument,chain_number
, to the single-chainsample
. In turn this keyword argument is propagated to the callback so it can be accessed.IMO the better solution is to actually make it a positional argument to the callback (when sampling a single chain it can either just be
1
ormissing
) but that would be a breaking change and I'm unsure if we want to do that in AbstractMCMC just for this. A keyword argument would just be a minor version bump. Thoughts welcome!