Gibbs sampler #2647
base: main
Changes from all commits: c0158ea, a972b5a, bdb7f73, c3cc773, 714c1e8, 97c571d, 94b723d
@@ -0,0 +1,173 @@
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl

"""
    GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

The `conditional` function should take a `NamedTuple` of conditioned variables and return
a `Distribution` from which to sample the variable `sym`.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
    m ~ Normal(0, sqrt(1 / λ))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :λ => GibbsConditional(:λ, cond_λ),
    :m => GibbsConditional(:m, cond_m)
), 1000)
```
"""

Review comment on the `b_new` line above: Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be `Gamma(2, inv(3))`?
struct GibbsConditional{C} <: InferenceAlgorithm
    conditional::C

    function GibbsConditional(sym::Symbol, conditional::C) where {C}
        return new{C}(conditional)
    end
end

# Mark GibbsConditional as a valid Gibbs component
isgibbscomponent(::GibbsConditional) = true

# Required methods for Gibbs constructor
Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable

"""
    find_global_varinfo(context, fallback_vi)

Traverse the context stack to find global variable information from
GibbsContext, ConditionContext, FixedContext, etc.
"""
function find_global_varinfo(context, fallback_vi)
    # Start with the given context and traverse down
    current_context = context

    while current_context !== nothing
        if current_context isa GibbsContext
            # Found GibbsContext, return its global varinfo
            return get_global_varinfo(current_context)
        elseif hasproperty(current_context, :childcontext) &&
               isdefined(DynamicPPL, :childcontext)
            # Move to child context if it exists
            current_context = DynamicPPL.childcontext(current_context)
        else
            # No more child contexts
            break
        end
    end

    # If no GibbsContext found, use the fallback
    return fallback_vi
end
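To make the traversal shape concrete, here is a self-contained toy version with made-up stand-in types (not DynamicPPL's real context machinery):

```julia
# Toy stand-ins: each wrapper stores its child in a `childcontext` field, and a
# Gibbs-like leaf carries the "global" varinfo. Purely illustrative.
struct ToyGibbsContext
    global_vi::String
end
struct ToyWrapper{C}
    childcontext::C
end

function toy_find_global_varinfo(context, fallback)
    while context !== nothing
        context isa ToyGibbsContext && return context.global_vi
        hasproperty(context, :childcontext) || return fallback  # no deeper level
        context = context.childcontext
    end
    return fallback
end

stack = ToyWrapper(ToyWrapper(ToyGibbsContext("global varinfo")))
toy_find_global_varinfo(stack, "fallback")    # "global varinfo"
toy_find_global_varinfo(nothing, "fallback")  # "fallback"
```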

"""
    DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Initialize the GibbsConditional sampler.
"""
function DynamicPPL.initialstep(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # GibbsConditional doesn't need any special initialization
    # Just return the initial state
    return nothing, vi
end

"""
    AbstractMCMC.step(rng, model, sampler::GibbsConditional, state)

Perform a step of GibbsConditional sampling.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional{S}},
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
) where {S}
    alg = sampler.alg

    # For GibbsConditional within Gibbs, we need to get all variable values
    # Traverse the context stack to find all conditioned/fixed/Gibbs variables
    global_vi = if hasproperty(model, :context)
        find_global_varinfo(model.context, state)
    else
        state
    end

    # Extract conditioned values as a NamedTuple
    # Include both random variables and observed data
    condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
    condvals_obs = NamedTuple{keys(model.args)}(model.args)
    condvals = merge(condvals_vars, condvals_obs)

    # Get the conditional distribution
    conddist = alg.conditional(condvals)

    # Sample from the conditional distribution
    updated = rand(rng, conddist)

    # Update the variable in state
    # The Gibbs sampler ensures that state only contains one variable
    # Get the variable name from the keys
    varname = first(keys(state))
    new_vi = DynamicPPL.setindex!!(state, updated, varname)

    return nothing, new_vi
end
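To make the data flow concrete, this is roughly what one update cycle does for the `inverse_gdemo` example from the docstring (illustrative values; the merged `NamedTuple` is exactly what the user's conditional receives):

```julia
# Hypothetical snapshot while the Gibbs component for :λ runs on
# inverse_gdemo([1.0, 2.0, 3.0]):
condvals = (λ = 0.8, m = 1.2, x = [1.0, 2.0, 3.0])  # values_as(...) merged with model.args
conddist = cond_λ(condvals)  # Gamma(a_new, 1 / b_new) from the user's conditional
updated = rand(conddist)     # fresh draw that replaces λ in the single-variable state
```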

"""
    setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo)

Update the variable info with new parameters for GibbsConditional.
"""
function setparams_varinfo!!(
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    state,
    params::DynamicPPL.AbstractVarInfo,
)
    # For GibbsConditional, we just return the params as-is since
    # the state is nothing and we don't need to update anything
    return params
end
@@ -0,0 +1,78 @@
using Turing
using Turing.Inference: GibbsConditional
using Distributions
using Random
using Statistics

# Test with the inverse gamma example from the issue
@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
    m ~ Normal(0, sqrt(1 / λ))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end
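For reference (not part of the diff): both closed forms follow from standard Normal-Gamma conjugacy when `b` is read as a rate. A sketch of the derivation:

```math
p(\lambda \mid m, x)
  \propto \lambda^{a-1} e^{-b\lambda}
  \cdot \lambda^{1/2} e^{-\lambda m^2/2}
  \cdot \lambda^{n/2} e^{-\frac{\lambda}{2}\sum_i (x_i - m)^2}
  = \lambda^{a+\frac{n+1}{2}-1}
    \exp\!\left(-\lambda \left[\, b + \tfrac{1}{2}\sum_i (x_i - m)^2 + \tfrac{m^2}{2} \right]\right),
\qquad
p(m \mid \lambda, x)
  \propto \exp\!\left(-\tfrac{\lambda}{2} m^2 - \tfrac{\lambda}{2}\sum_i (x_i - m)^2\right)
  \;\Longrightarrow\;
  m \mid \lambda, x \sim \mathcal{N}\!\left(\frac{\sum_i x_i}{n+1},\; \frac{1}{\lambda (n+1)}\right),
```

where the model is taken as λ ~ Gamma(a, rate = b), m | λ ~ N(0, 1/λ), xᵢ | m, λ ~ N(m, 1/λ). This matches `a_new`, `b_new`, `m_mean`, and `m_var` above; the reviewer's `Gamma(2, inv(3))` question is about whether the model prior uses this same rate reading of `b`.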
|
||||||
# Generate some observed data | ||||||
Random.seed!(42) | ||||||
x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] | ||||||
|
||||||
# Create the model | ||||||
model = inverse_gdemo(x_obs) | ||||||
|
||||||
# Sample using GibbsConditional | ||||||
println("Testing GibbsConditional sampler...") | ||||||

Review comment on the `println` line above: You probably want various …

sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))

Copilot review comment (nitpick): The variable name is redundantly specified in both the Gibbs pair key and the GibbsConditional constructor. Consider if this duplication is necessary or if the API could be simplified.
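A sketch of what the deduplicated call could look like if the constructor dropped its unused `sym` argument (hypothetical API, not what this diff implements):

```julia
# Hypothetical: GibbsConditional(conditional), with the variable name taken
# solely from the Gibbs pair key.
sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m))
```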

# Run a short chain to test
chain = sample(model, sampler, 100)

println("Sampling completed successfully!")
println("\nChain summary:")
println(chain)

# Extract samples
λ_samples = vec(chain[:λ])
m_samples = vec(chain[:m])

println("\nλ statistics:")
println("  Mean: ", mean(λ_samples))
println("  Std: ", std(λ_samples))
println("  Min: ", minimum(λ_samples))
println("  Max: ", maximum(λ_samples))

println("\nm statistics:")
println("  Mean: ", mean(m_samples))
println("  Std: ", std(m_samples))
println("  Min: ", minimum(m_samples))
println("  Max: ", maximum(m_samples))

# Test mixing with other samplers
println("\n\nTesting mixed samplers...")
sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH())

chain2 = sample(model, sampler2, 100)
println("Mixed sampling completed successfully!")
println("\nMixed chain summary:")
println(chain2)
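If, as the truncated review comment above suggests, the script is meant to assert rather than just print, a minimal sketch of what checks could look like (tolerances are rough guesses for a 100-draw chain):

```julia
using Test

# Under this conjugate model the posterior mean of m is exactly Σxᵢ / (n + 1)
# (the conditional mean does not depend on λ), so the chain mean should land near it.
@testset "GibbsConditional sanity checks" begin
    @test mean(m_samples) ≈ sum(x_obs) / (length(x_obs) + 1) atol = 0.5
    @test all(λ_samples .> 0)  # λ is a precision, so every draw must be positive
end
```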

Review comment: Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be `Gamma(2, inv(3))`?