Implement GibbsConditional #2647
Conversation
Pull Request Overview
This PR implements a GibbsConditional sampler component for Turing.jl that allows users to provide analytical conditional distributions for variables in Gibbs sampling. The implementation enables mixing user-defined conditional distributions with other MCMC samplers within the Gibbs framework.
Key changes:
- Added GibbsConditional struct and supporting functions for analytical conditional sampling
- Comprehensive test coverage for the new functionality
- Added example test file demonstrating usage
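To make the overview concrete, here is a hypothetical usage sketch based on the description above. The model, the conditionals function, and especially the `Gibbs(@varname(m) => GibbsConditional(cond_m), ...)` pairing syntax are assumptions for illustration, not taken from the diff.

```julia
using Turing

# Hypothetical model: normal observations with unknown precision λ and mean m.
@model function inverse_gdemo(x)
    λ ~ Gamma(2, inv(3))
    m ~ Normal(0, sqrt(1 / λ))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# User-written analytical full conditional for m given λ and the data x.
# `c` holds the values of all the other variables.
function cond_m(c)
    λ, x = c[@varname(λ)], c[@varname(x)]
    n = length(x)
    m_var = 1 / ((n + 1) * λ)    # posterior variance: prior precision λ plus data precision nλ
    m_mean = m_var * λ * sum(x)  # posterior mean (the prior mean is 0)
    return Normal(m_mean, sqrt(m_var))
end

# Mix the analytical conditional with another MCMC sampler inside Gibbs.
chain = sample(
    inverse_gdemo([1.5, 2.0, 0.3]),
    Gibbs(@varname(m) => GibbsConditional(cond_m), @varname(λ) => MH()),
    1000,
)
```

Returning a Distributions.jl distribution lets GibbsConditional draw exact samples from the full conditional instead of running an inner MCMC step for that variable.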
Reviewed Changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/mcmc/gibbs_conditional.jl | Core implementation of the GibbsConditional sampler with step functions and variable handling |
| src/mcmc/Inference.jl | Added GibbsConditional export and module inclusion |
| test/mcmc/gibbs.jl | Added comprehensive test suite for GibbsConditional functionality |
| test_gibbs_conditional.jl | Example/demo file showing GibbsConditional usage |
| HISTORY.md | Version history update (unrelated to the main feature) |
Turing.jl documentation for PR #2647 is available at:
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@ Coverage Diff @@
## main #2647 +/- ##
==========================================
+ Coverage 86.47% 86.80% +0.32%
==========================================
Files 21 22 +1
Lines 1420 1455 +35
==========================================
+ Hits 1228 1263 +35
Misses 192 192
```
Pull Request Test Coverage Report for Build 18095923386

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch, which means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
mhauru left a comment:
As discussed on Slack, rather than doing a full review, I'm just going to give some high level comments and pointers for where to find more details on some of the relevant context.
mhauru left a comment:
I left a couple of new comments.
Also, a few of the comments from my first review remain to be addressed and all the newly added tests are still failing.
Hi, I think this PR needs to be worked on as quickly as possible, given that many people (#2547 and on Slack) have been requesting that this feature be added back. I could take this over if nobody has bandwidth to work on it. However, I am a bit less familiar with the relevant parts of Turing. @mhauru @penelopeysm, could you take a look and see if you have any suggestions?
Aoife is no longer working on Turing, and I said I would take over this one, but I just haven't found the time. Let me see what the state of this is and report back.
I fixed and improved the implementation and added a lot of tests. I think this may be done now, but I'll reread it myself tomorrow before requesting reviews.
The Julia package registry seems to be having trouble; I'll rerun CI in a moment. This is ready for review, though. @Red-Portal, would you be happy to take an overall look? @penelopeysm, could I ask you to look at the implementation and the HISTORY.md entry at the least? Up to you if you want to also read the tests.
src/mcmc/gibbs_conditional.jl (outdated):

```julia
m = c[@varname(m)]
x = c[@varname(x)]
```
I feel that the c[@varname(x)] here is deceptively simple... it will work under the following scenarios:

- x is passed as a single vector to the model arguments
- x is not in the arguments, but is conditioned on as a vector, i.e. model() | (; x = vec)

It will fail in the case where:

- x is not in the arguments, but is conditioned on as individual elements, i.e. model() | Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0)

Notice that because the model has x[i] ~ Normal, all three cases will work correctly with plain model evaluation. But one of the cases will fail with GibbsConditional.
I don't really know how to fix this, and I don't know whether it should even be fixed, but it makes me quite uncomfortable.
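The three scenarios can be sketched as follows; the model `demo` is made up for illustration.

```julia
using Turing

# Model with elementwise x[i] ~ Normal(m); when x is left as `missing`
# entries, the x[i] are sampled, otherwise they are observed.
@model function demo(x = Vector{Union{Missing,Float64}}(undef, 2))
    m ~ Normal()
    for i in eachindex(x)
        x[i] ~ Normal(m)
    end
end

demo([1.0, 2.0])                        # (1) x passed as a model argument
demo() | (; x = [1.0, 2.0])             # (2) conditioned on as a single vector
demo() | Dict(@varname(x[1]) => 1.0,
              @varname(x[2]) => 2.0)    # (3) conditioned on element by element

# Plain model evaluation treats all three equivalently, but a conditionals
# function doing c[@varname(x)] only sees x in cases (1) and (2).
```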
I've fixed all the other points you raised. This one will have to wait until tomorrow.
As a simple enough solution, could we maybe take the model arguments out of the dictionary, and instead give c a field called c.model_args = deepcopy(model.args)? Then people can use c.model_args.x. Deepcopy would be needed to avoid accidental aliasing. (Or we could not copy, and leave it to the user to copy if they need it.)
The outcome of this is:

- x is passed as a single vector to the model arguments: OK, people can use c.model_args.x, clear.
- x is not in the arguments, but is conditioned on as a vector, i.e. model() | (; x = vec): OK, people can use c[@varname(x)] because that's what they conditioned on.
- x is not in the arguments, but is conditioned on as individual elements, i.e. model() | Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0): this will still fail, but at least there's a good explanation for it: they conditioned on x[1] and x[2] rather than x as a whole.
Looking at my previous comment, actually the original implementation is also explainable along the same lines. I guess the tldr is basically, if you supply x as a single thing, that's fine. So actually I'd be OK with leaving it as is.
The root cause of this is essentially the same as why you can't do

```julia
@model function f()
    x ~ MvNormal()
end
m = condition(f(), Dict(@varname(x[1]) => 1.0))
```

it's just that in that case you request the value of the conditioned x in the model, whereas here you request it in the conditionals function. So in some sense this is a bigger problem that needs to be fixed lower down (https://github.com/TuringLang/DynamicPPL.jl/issues/11480). So I'd be okay with leaving this be too.
The one thing that I was wondering about is that ConditionContext internally uses getvalue. If the user did the same in the conditionals function then at least they could get x1 = getvalue(c, @varname(x[1])) to work even if they conditioned on x as a whole. Thus I'm now thinking that maybe the best thing to do here would be to set a good example in the docstring and use getvalue rather than getindex, and leave it at that.
As a silver lining, if this goes wrong, at least the user gets a clear "key not found" error, and the error comes from code they wrote, in quite a traceable way.
Yeah, that sounds good to me! Agree that getvalue is better, albeit a tiny bit more annoying since it has to be imported from AbstractPPL
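A docstring example along those lines might look like this. This is a sketch, assuming getvalue is imported from AbstractPPL and that `c` behaves like a VarName-keyed mapping; the conditional returned is illustrative.

```julia
using AbstractPPL: getvalue
using Distributions

# Conditionals function using getvalue instead of getindex. getvalue can
# resolve a subsumed varname such as x[1] even when the user conditioned
# on x as a whole vector, which plain getindex (c[@varname(x[1])]) cannot.
function cond_m(c)
    λ = getvalue(c, @varname(λ))
    x1 = getvalue(c, @varname(x[1]))  # works whether x was supplied whole or element-wise
    x2 = getvalue(c, @varname(x[2]))
    return Normal((x1 + x2) / 3, sqrt(1 / (3λ)))
end
```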
src/mcmc/gibbs_conditional.jl (outdated):

```julia
# TODO(mhauru) Can we avoid invlinking all the time? Note that this causes a model
# evaluation, which may be expensive.
```
I think for typed VarInfo it shouldn't need to evaluate the model. Obviously it still has a cost, just not as much as a model evaluation.

```julia
julia> @model function f()
           @info "hi"
           x ~ Normal()
       end
f (generic function with 2 methods)

julia> model = f(); v = VarInfo(model);
[ Info: hi

julia> v2 = DynamicPPL.link!!(v, model); v3 = DynamicPPL.invlink!!(v, model);
```
Adjusted the comment to reflect this.
src/mcmc/gibbs_conditional.jl (outdated):

```julia
(DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(cond_nt))...,
(DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(fixed_nt))...,
(DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(model.args))...,
```
cond_nt and fixed_nt might not be NamedTuples; they might be dicts, in which case this will fail in a very weird manner.

```julia
julia> VarName{@varname(x)}()
x

julia> VarName{@varname(x)}() == @varname(x)  # not the same thing
false
```

I think you have to convert them to dicts first. DynamicPPL.to_varname_dict will do this (albeit inefficiently, TuringLang/DynamicPPL.jl#1134).
Nice catch. I fixed this and added a test that would have caught this.
Making so many Dicts is sad, but for now I just want to get this working, and worry about performance later.
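A sketch of the fix being described, using DynamicPPL.to_varname_dict so that NamedTuple- and Dict-valued conditioning values are handled uniformly. The variable names and the merge order here are illustrative assumptions, not the PR's actual code.

```julia
using DynamicPPL

# cond_nt / fixed_nt may be NamedTuples or Dicts keyed by VarName.
# to_varname_dict normalizes both into a Dict with proper VarName keys,
# so we never construct VarName{sym}() from what might be a VarName itself.
cond_dict  = DynamicPPL.to_varname_dict(cond_nt)
fixed_dict = DynamicPPL.to_varname_dict(fixed_nt)
args_dict  = DynamicPPL.to_varname_dict(model.args)

# Later entries win in merge; here conditioned values take precedence.
c = merge(args_dict, fixed_dict, cond_dict)
```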
```julia
prior_var = 100.0 # 10^2
post_var = 1 / (1 / prior_var + n / var)
post_mean = post_var * (0 / prior_var + sum(x) / var)
```
Suggested change:

```diff
  prior_var = 100.0 # 10^2
  post_var = 1 / (1 / prior_var + n / var)
- post_mean = post_var * (0 / prior_var + sum(x) / var)
+ post_mean = post_var * (sum(x) / var)
```
again - are there formulas for this somewhere? Maybe a statistician looks at this and goes 'ah yes of course', but I feel quite uncomfortable.
Will add the link. I didn't write this part, but I think it's deliberately in this format to match the formula I found, e.g., on Wikipedia. The 0, I think, is the mean of the prior.
Ah okay in that case it can stay!
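For reference, the standard conjugate update for a normal likelihood with known variance $\sigma^2$ and a $\mathcal{N}(\mu_0, \sigma_0^2)$ prior on the mean (here $\mu_0 = 0$ and $\sigma_0^2 = 100$) is:

```latex
\sigma_{\mathrm{post}}^2
  = \left( \frac{1}{\sigma_0^2} + \frac{n}{\sigma^2} \right)^{-1},
\qquad
\mu_{\mathrm{post}}
  = \sigma_{\mathrm{post}}^2
    \left( \frac{\mu_0}{\sigma_0^2} + \frac{\sum_{i=1}^{n} x_i}{\sigma^2} \right)
```

With $\mu_0 = 0$ the term $\mu_0 / \sigma_0^2$ vanishes, which is why the two forms of post_mean in the suggestion are equivalent.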
penelopeysm left a comment:
Thanks, great work!
WRT: #2547