Gibbs sampler #2647
@@ -0,0 +1,245 @@

````julia
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend
# the generic implementations in gibbs.jl

"""
    GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

The `conditional` function should take a `NamedTuple` of conditioned variables and return
a `Distribution` from which to sample the variable `sym`.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
    m ~ Normal(0, sqrt(1 / λ))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
````
Review comment on the line above: Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the …
````julia
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :λ => GibbsConditional(:λ, cond_λ),
    :m => GibbsConditional(:m, cond_m)
), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm
    conditional::C

    function GibbsConditional(sym::Symbol, conditional::C) where {C}
        return new{sym,C}(conditional)
    end
end

# Mark GibbsConditional as a valid Gibbs component
isgibbscomponent(::GibbsConditional) = true

"""
    DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Initialize the GibbsConditional sampler.
"""
function DynamicPPL.initialstep(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # GibbsConditional doesn't need any special initialization
    # Just return the initial state
    return nothing, vi
end

"""
    AbstractMCMC.step(rng, model, sampler::GibbsConditional, state)

Perform a step of GibbsConditional sampling.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional{S}},
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
) where {S}
    alg = sampler.alg

    # For GibbsConditional within Gibbs, we need to get all variable values
    # Check if we're in a Gibbs context
    global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
````
Copilot suggested change:

```diff
- global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
+ global_vi = if isdefined(model, :context) && model.context isa GibbsContext
```
Review comment (outdated):

The core idea here of finding the possible `GibbsContext` and getting the global varinfo from it is good. However, `GibbsContext` is a bit weird, in that it's always inserted at the bottom of the context stack. By the context stack I mean the fact that contexts often have child contexts, and thus `model.context` may in fact be many nested contexts. See e.g. how the `GibbsContext` is set by calling `setleafcontext` rather than `setcontext` (line 258 at commit d75e6f2):

```julia
gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner)
```

So rather than check whether `model.context isa GibbsContext`, I think you'll need to traverse the whole context stack, and check if any of them are a `GibbsContext`, until you hit a leaf context and the stack ends.

Moreover, I think you'll need to check not just for `GibbsContext`, but also for `ConditionContext` and `FixedContext`, which condition/fix the values of some variables. So all in all, if you go through the whole stack, starting with `model.context` and going through its child contexts, and collect any variables set in `ConditionContext`, `FixedContext`, and `GibbsContext`, that should give you all of the variable values you need. See here for more details on condition and fix: https://github.com/TuringLang/DynamicPPL.jl/blob/1ed8cc8d9f013f46806c88a83e93f7a4c5b891dd/src/contexts.jl#L258

As mentioned on Slack a week or two ago, all this context stack business is likely changing Soon (TM), since @penelopeysm is overhauling `condition` and `fix` over in TuringLang/DynamicPPL.jl#1010, and as a result we may be able to overhaul `GibbsContext` as well. You could wait for that to be finished first, at least if it looks like getting this to work would be a lot of work.
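
A minimal sketch of the traversal the comment describes, to make the idea concrete. The accessor `get_global_varinfo` for pulling the global varinfo out of a `GibbsContext` is hypothetical, and this assumes `DynamicPPL.conditioned`/`DynamicPPL.fixed` return `NamedTuple`s (for `VarName`-keyed conditioning they can return `Dict`s instead):

```julia
using DynamicPPL

# Sketch: walk the context stack from `model.context` down to the leaf,
# collecting variable values from any context that carries them.
function collect_context_values(context)
    vals = NamedTuple()
    ctx = context
    while true
        if ctx isa GibbsContext
            # hypothetical accessor for the global varinfo stored in GibbsContext
            vals = merge(vals, DynamicPPL.values_as(get_global_varinfo(ctx), NamedTuple))
        elseif ctx isa DynamicPPL.ConditionContext
            vals = merge(vals, DynamicPPL.conditioned(ctx))
        elseif ctx isa DynamicPPL.FixedContext
            vals = merge(vals, DynamicPPL.fixed(ctx))
        end
        # stop once the stack ends at a leaf context
        DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf && break
        ctx = DynamicPPL.childcontext(ctx)
    end
    return vals
end
```

The `NodeTrait`/`childcontext` walk mirrors how DynamicPPL itself traverses nested contexts.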
Review comment (outdated):

The operating principle of the new(ish) Gibbs sampler is that every component sampler only ever sees a VarInfo with the variables that that component sampler is supposed to sample. Thus, you should be able to assume that `updated` includes values for all the variables in `state`, and for nothing else. Hence the below checks and loops I think shouldn't be necessary. The solution might be as simple as `new_state = unflatten(state, updated)`, though there may be details there that I'm not thinking of right now. (What if `state` is linked? But maybe we can guarantee that it's never linked, because the sampler can control it.) Happy to discuss details more if `unflatten` by itself doesn't seem to cut it.
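
To illustrate, a rough sketch of what the tail of `step` could collapse to under the reviewer's assumption; `condvals` (the collected values of all other variables) is hypothetical, and linked varinfos are ignored:

```julia
# Sketch: `state` is assumed to hold exactly the variables this component
# samples, so one draw from the user's conditional plus unflatten suffices.
conddist = sampler.alg.conditional(condvals)  # user-supplied conditional distribution
updated = rand(rng, conddist)
# unflatten writes the flat value vector back into the varinfo
new_state = DynamicPPL.unflatten(state, updated isa AbstractVector ? updated : [updated])
return nothing, new_state
```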
Copilot review comment (outdated):

The `local` keyword is unnecessary here since `updated_vi` is already in a local scope. This adds visual clutter without functional benefit.

Suggested change:

```diff
- local updated_vi = state
+ updated_vi = state
```
Copilot review comment (outdated):

The error message could be more helpful by suggesting what variables are available or providing debugging information about the VarInfo contents.

Suggested change:

```diff
- error("Could not find variable $S in VarInfo")
+ error("Could not find variable $S in VarInfo. Available variables: $(join([string(DynamicPPL.getsym(k)) for k in keys(state)], ", ")).")
```
Review comment (outdated):

I think you shouldn't need this, because the log joint is going to be recomputed anyway by the Gibbs sampler once it's looped over all component samplers. Saves one model evaluation.
Review comment (outdated):

I would hope you wouldn't need to overload this or `gibbs_initialstep_recursive`. Also, the below implementation seems to be just a repeat of `step`.
Review comment:

Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be `Gamma(2, inv(3))`?
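
The parameterization point here can be checked directly: `Gamma(α, θ)` in Distributions.jl is shape-scale, while the conjugate update in `cond_λ` treats `b = 3.0` as a rate, so a shape-2, rate-3 prior corresponds to `Gamma(2, inv(3))`. A quick sanity check, assuming only Distributions.jl:

```julia
using Distributions

# Gamma(α, θ) is shape-scale: mean(Gamma(α, θ)) == α * θ.
mean(Gamma(2, 3))       # 6.0    -> scale 3, i.e. rate 1/3
mean(Gamma(2, inv(3)))  # 0.6667 -> scale 1/3, i.e. rate 3
```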