-
Notifications
You must be signed in to change notification settings - Fork 230
Implement GibbsConditional #2647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AoifeHughes
wants to merge
26
commits into
main
Choose a base branch
from
gibbs-sampler
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
c0158ea
Add GibbsConditional sampler and corresponding tests
a972b5a
clarified comment
bdb7f73
Merge branch 'main' into gibbs-sampler
mhauru c3cc773
add MHs suggestions
714c1e8
formatter
97c571d
Merge branch 'gibbs-sampler' of github.com:TuringLang/Turing.jl into …
94b723d
fixed exporting thing
891ac14
Merge branch 'main' into gibbs-sampler
AoifeHughes 2058ae5
Refactor Gibbs sampler to use inverse of parameters for Gamma distrib…
b0812a3
removed file added by mistake
d910312
Add safety checks and error handling in find_global_varinfo and Abstr…
4b1dc2f
imports?
1e84309
Merge remote-tracking branch 'origin/main' into gibbs-sampler
mhauru a33d8a9
Fixes and improvements for GibbsConditional
mhauru f41fc6e
Move GibbsConditional tests to their own file
mhauru bd5ff0b
More GibbsConditional tests
mhauru 34acad7
Bump patch version to 0.41.2, add HISTORY.md entry
mhauru 4786a59
Remove spurious change
mhauru 8951d98
Code style and documentation
mhauru 45ab5f8
Add one test_throws, tweak test thresholds and dimensions
mhauru 805bc60
Apply suggestions from code review
mhauru d0c3cf4
Add links for where to get analytical posteriors
mhauru 98f4213
Update TODO note
mhauru c74b0a0
Fix a GibbsConditional bug, add a test
mhauru 4a7d08c
Set seeds better
mhauru 744d254
Use getvalue in docstring
mhauru File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| using DynamicPPL: VarName | ||
| using Random: Random | ||
| import AbstractMCMC | ||
|
|
||
| # These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl | ||
|
|
||
| """ | ||
| GibbsConditional(sym::Symbol, conditional) | ||
|
|
||
| A Gibbs sampler component that samples a variable according to a user-provided | ||
| analytical conditional distribution. | ||
|
|
||
| The `conditional` function should take a `NamedTuple` of conditioned variables and return | ||
| a `Distribution` from which to sample the variable `sym`. | ||
|
|
||
| # Examples | ||
|
|
||
| ```julia | ||
| # Define a model | ||
| @model function inverse_gdemo(x) | ||
| λ ~ Gamma(2, 3) | ||
| m ~ Normal(0, sqrt(1 / λ)) | ||
| for i in 1:length(x) | ||
| x[i] ~ Normal(m, sqrt(1 / λ)) | ||
| end | ||
| end | ||
|
|
||
| # Define analytical conditionals | ||
| function cond_λ(c::NamedTuple) | ||
| a = 2.0 | ||
| b = 3.0 | ||
| m = c.m | ||
| x = c.x | ||
| n = length(x) | ||
| a_new = a + (n + 1) / 2 | ||
| b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return Gamma(a_new, 1 / b_new) | ||
| end | ||
|
|
||
| function cond_m(c::NamedTuple) | ||
| λ = c.λ | ||
| x = c.x | ||
| n = length(x) | ||
| m_mean = sum(x) / (n + 1) | ||
| m_var = 1 / (λ * (n + 1)) | ||
| return Normal(m_mean, sqrt(m_var)) | ||
| end | ||
|
|
||
| # Sample using GibbsConditional | ||
| model = inverse_gdemo([1.0, 2.0, 3.0]) | ||
| chain = sample(model, Gibbs( | ||
| :λ => GibbsConditional(:λ, cond_λ), | ||
| :m => GibbsConditional(:m, cond_m) | ||
| ), 1000) | ||
| ``` | ||
| """ | ||
| struct GibbsConditional{S,C} <: InferenceAlgorithm | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| conditional::C | ||
|
|
||
| function GibbsConditional(sym::Symbol, conditional::C) where {C} | ||
| return new{sym,C}(conditional) | ||
| end | ||
| end | ||
|
|
||
| # Mark GibbsConditional as a valid Gibbs component | ||
| isgibbscomponent(::GibbsConditional) = true | ||
|
|
||
| """ | ||
| DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) | ||
|
|
||
| Initialize the GibbsConditional sampler. | ||
| """ | ||
| function DynamicPPL.initialstep( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| sampler::DynamicPPL.Sampler{<:GibbsConditional}, | ||
| vi::DynamicPPL.AbstractVarInfo; | ||
| kwargs..., | ||
| ) | ||
| # GibbsConditional doesn't need any special initialization | ||
| # Just return the initial state | ||
| return nothing, vi | ||
| end | ||
|
|
||
| """ | ||
| AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) | ||
|
|
||
| Perform a step of GibbsConditional sampling. | ||
| """ | ||
| function AbstractMCMC.step( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, | ||
| state::DynamicPPL.AbstractVarInfo; | ||
| kwargs..., | ||
| ) where {S} | ||
| alg = sampler.alg | ||
|
|
||
| # For GibbsConditional within Gibbs, we need to get all variable values | ||
| # Check if we're in a Gibbs context | ||
| global_vi = if hasproperty(model, :context) && model.context isa GibbsContext | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # We're in a Gibbs context, get the global varinfo | ||
| get_global_varinfo(model.context) | ||
| else | ||
| # We're not in a Gibbs context, use the current state | ||
| state | ||
| end | ||
|
|
||
| # Extract conditioned values as a NamedTuple | ||
| # Include both random variables and observed data | ||
| condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) | ||
| condvals_obs = NamedTuple{keys(model.args)}(model.args) | ||
| condvals = merge(condvals_vars, condvals_obs) | ||
|
|
||
| # Get the conditional distribution | ||
| conddist = alg.conditional(condvals) | ||
|
|
||
| # Sample from the conditional distribution | ||
| updated = rand(rng, conddist) | ||
|
|
||
| # Update the variable in state | ||
| # We need to get the actual VarName for this variable | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # The symbol S tells us which variable to update | ||
| vn = VarName{S}() | ||
|
|
||
| # Check if the variable needs to be a vector | ||
| new_vi = if haskey(state, vn) | ||
| # Update the existing variable | ||
| DynamicPPL.setindex!!(state, updated, vn) | ||
| else | ||
| # Try to find the variable with indices | ||
| # This handles cases where the variable might have indices | ||
| local updated_vi = state | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| found = false | ||
| for key in keys(state) | ||
| if DynamicPPL.getsym(key) == S | ||
| updated_vi = DynamicPPL.setindex!!(state, updated, key) | ||
| found = true | ||
| break | ||
| end | ||
| end | ||
| if !found | ||
| error("Could not find variable $S in VarInfo") | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| end | ||
| updated_vi | ||
| end | ||
|
|
||
| # Update log joint probability | ||
| new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| return nothing, new_vi | ||
| end | ||
|
|
||
| """ | ||
| setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) | ||
|
|
||
| Update the variable info with new parameters for GibbsConditional. | ||
| """ | ||
| function setparams_varinfo!!( | ||
| model::DynamicPPL.Model, | ||
| sampler::DynamicPPL.Sampler{<:GibbsConditional}, | ||
| state, | ||
| params::DynamicPPL.AbstractVarInfo, | ||
| ) | ||
| # For GibbsConditional, we just return the params as-is since | ||
| # the state is nothing and we don't need to update anything | ||
| return params | ||
| end | ||
|
|
||
| """ | ||
| gibbs_initialstep_recursive( | ||
| rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state | ||
| ) | ||
|
|
||
| Initialize the GibbsConditional sampler. | ||
| """ | ||
| function gibbs_initialstep_recursive( | ||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, | ||
| target_varnames::AbstractVector{<:VarName}, | ||
| global_vi::DynamicPPL.AbstractVarInfo, | ||
| prev_state, | ||
| ) | ||
| # GibbsConditional doesn't need any special initialization | ||
| # Just perform one sampling step | ||
| return gibbs_step_recursive( | ||
| rng, model, sampler_wrapped, target_varnames, global_vi, nothing | ||
| ) | ||
| end | ||
|
|
||
| """ | ||
| gibbs_step_recursive( | ||
| rng, model, sampler::GibbsConditional, target_varnames, global_vi, state | ||
| ) | ||
|
|
||
| Perform a single step of GibbsConditional sampling. | ||
| """ | ||
| function gibbs_step_recursive( | ||
mhauru marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| rng::Random.AbstractRNG, | ||
| model::DynamicPPL.Model, | ||
| sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, | ||
| target_varnames::AbstractVector{<:VarName}, | ||
| global_vi::DynamicPPL.AbstractVarInfo, | ||
| state, | ||
| ) where {S} | ||
| sampler = sampler_wrapped.alg | ||
|
|
||
| # Extract conditioned values as a NamedTuple | ||
| # Include both random variables and observed data | ||
| condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) | ||
| condvals_obs = NamedTuple{keys(model.args)}(model.args) | ||
| condvals = merge(condvals_vars, condvals_obs) | ||
|
|
||
| # Get the conditional distribution | ||
| conddist = sampler.conditional(condvals) | ||
|
|
||
| # Sample from the conditional distribution | ||
| updated = rand(rng, conddist) | ||
|
|
||
| # Update the variable in global_vi | ||
| # We need to get the actual VarName for this variable | ||
| # The symbol S tells us which variable to update | ||
| vn = VarName{S}() | ||
|
|
||
| # Check if the variable needs to be a vector | ||
| if haskey(global_vi, vn) | ||
| # Update the existing variable | ||
| global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) | ||
| else | ||
| # Try to find the variable with indices | ||
| # This handles cases where the variable might have indices | ||
| for key in keys(global_vi) | ||
| if DynamicPPL.getsym(key) == S | ||
| global_vi = DynamicPPL.setindex!!(global_vi, updated, key) | ||
| break | ||
| end | ||
| end | ||
| end | ||
|
|
||
| # Update log joint probability | ||
| global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) | ||
|
|
||
| return nothing, global_vi | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.