Gibbs sampler #2647

Open · wants to merge 7 commits into base: main

Changes from 2 commits
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl

@@ -67,6 +67,7 @@ export InferenceAlgorithm,
    ESS,
    Emcee,
    Gibbs, # classic sampling
    GibbsConditional, # conditional sampling
    HMC,
    SGLD,
    PolynomialStepsize,
@@ -392,6 +393,7 @@ include("mh.jl")
include("is.jl")
include("particle_mcmc.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("sghmc.jl")
include("emcee.jl")
include("prior.jl")
245 changes: 245 additions & 0 deletions src/mcmc/gibbs_conditional.jl

@@ -0,0 +1,245 @@
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl

"""
GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

The `conditional` function should take a `NamedTuple` of conditioned variables and return
a `Distribution` from which to sample the variable `sym`.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
λ ~ Gamma(2, 3)

Review comment (Member): Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be Gamma(2, inv(3))?
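For context on this question (an illustration, not a resolution of the review thread): Distributions.jl parameterizes Gamma(α, θ) by shape and scale, so a prior stated with rate b corresponds to Gamma(a, inv(b)). A minimal sketch:

```julia
using Distributions

mean(Gamma(2, 3))       # 6.0: shape 2, scale 3
mean(Gamma(2, inv(3)))  # ≈ 0.6667: shape 2, scale 1/3, i.e. rate 3
```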

    m ~ Normal(0, sqrt(1 / λ))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2

Review comment (Member): Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the m in the variance term rather be the mean of x?
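One algebraic point that bears on this question (a quick check, not a verdict): centering the sum of squares at m versus at mean(x) differs by exactly an n * (mean(x) - m)^2 term, so the two docstring forms agree only if that term is carried along.

```julia
using Statistics: mean

x = [1.0, 2.0, 3.0]; m = 0.7; n = length(x)
lhs = sum(abs2, x .- m)
rhs = sum(abs2, x .- mean(x)) + n * (mean(x) - m)^2
lhs ≈ rhs  # true for any x and m
```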

    return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :λ => GibbsConditional(:λ, cond_λ),
    :m => GibbsConditional(:m, cond_m)
), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm

Review comment (Member): I think the type parameter S shouldn't be necessary once we don't use it any more to construct the VarName that is being sampled. See below comments for more details.
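A hedged sketch of the S-free struct this comment suggests (assuming the Gibbs wrapper can supply the target variable itself, so the component never constructs VarName{S}(); not code from this PR):

```julia
# The symbol argument disappears along with the type parameter; the Gibbs
# wrapper already knows which variable this component is responsible for.
struct GibbsConditional{C} <: InferenceAlgorithm
    conditional::C
end
```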

    conditional::C

    function GibbsConditional(sym::Symbol, conditional::C) where {C}
        return new{sym,C}(conditional)
    end
end

# Mark GibbsConditional as a valid Gibbs component
isgibbscomponent(::GibbsConditional) = true

"""
DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Initialize the GibbsConditional sampler.
"""
function DynamicPPL.initialstep(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:GibbsConditional},
vi::DynamicPPL.AbstractVarInfo;
kwargs...,
)
# GibbsConditional doesn't need any special initialization
# Just return the initial state
return nothing, vi
end

"""
AbstractMCMC.step(rng, model, sampler::GibbsConditional, state)

Perform a step of GibbsConditional sampling.
"""
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:GibbsConditional{S}},
state::DynamicPPL.AbstractVarInfo;
kwargs...,
) where {S}
alg = sampler.alg

# For GibbsConditional within Gibbs, we need to get all variable values
# Check if we're in a Gibbs context
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext

Review comment (Copilot AI, Aug 7, 2025): Using hasproperty for checking if a field exists is fragile and could break with future changes to the model structure. Consider using a more robust method like isdefined or checking the model type directly.

Suggested change:
-    global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
+    global_vi = if isdefined(model, :context) && model.context isa GibbsContext

Review comment (Member): The core idea here of finding the possible GibbsContext and getting the global varinfo from it is good. However, GibbsContext is a bit weird, in that it's always inserted at the bottom of the context stack. By the context stack I mean the fact that contexts often have child contexts, and thus model.context may in fact be many nested contexts. See e.g. how the GibbsContext is set here, by calling setleafcontext rather than setcontext:

    gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner)

So rather than check whether model.context isa GibbsContext, I think you'll need to traverse the whole context stack, and check if any of them are a GibbsContext, until you hit a leaf context and the stack ends. A sketch of such a traversal follows this comment.

Moreover, I think you'll need to check not just for GibbsContext, but also for ConditionContext and FixedContext, which condition/fix the values of some variables. So all in all, if you go through the whole stack, starting with model.context and going through its child contexts, and collect any variables set in ConditionContext, FixedContext, and GibbsContext, that should give you all of the variable values you need. See here for more details on condition and fix: https://github.com/TuringLang/DynamicPPL.jl/blob/1ed8cc8d9f013f46806c88a83e93f7a4c5b891dd/src/contexts.jl#L258

As mentioned on Slack a week or two ago, all this context stack business is likely changing Soon (TM), since @penelopeysm is overhauling condition and fix over here, TuringLang/DynamicPPL.jl#1010, and as a result we may be able to overhaul GibbsContext as well. You could wait for that to be finished first, at least if it looks like getting this to work would be a lot of work.
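A minimal sketch of the traversal described above, assuming DynamicPPL's NodeTrait/childcontext API (names as in DynamicPPL's contexts.jl; illustrative, not part of the PR):

```julia
# Walk the context stack from the top, returning the first GibbsContext found,
# or `nothing` once we reach the leaf without encountering one.
function find_gibbs_context(context)
    context isa GibbsContext && return context
    if DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf
        return nothing
    end
    return find_gibbs_context(DynamicPPL.childcontext(context))
end
```

The same walk could also collect variable values from any ConditionContext and FixedContext nodes it passes.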

        # We're in a Gibbs context, get the global varinfo
        get_global_varinfo(model.context)
    else
        # We're not in a Gibbs context, use the current state
        state
    end

    # Extract conditioned values as a NamedTuple
    # Include both random variables and observed data
    condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
    condvals_obs = NamedTuple{keys(model.args)}(model.args)
    condvals = merge(condvals_vars, condvals_obs)
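    # (Illustration, not produced by this code: for the inverse_gdemo example
    # above, `condvals` would look roughly like
    # (λ = 0.4, m = 1.2, x = [1.0, 2.0, 3.0]), i.e. the latent draws from the
    # global varinfo merged with the observed model arguments.)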

    # Get the conditional distribution
    conddist = alg.conditional(condvals)

    # Sample from the conditional distribution
    updated = rand(rng, conddist)

    # Update the variable in state
    # We need to get the actual VarName for this variable

Review comment (Member): The operating principle of the new(ish) Gibbs sampler is that every component sampler only ever sees a VarInfo with the variables that that component sampler is supposed to sample. Thus, you should be able to assume that updated includes values for all the variables in state, and for nothing else. Hence the below checks and loops I think shouldn't be necessary. The solution might be as simple as new_state = unflatten(state, updated), though there may be details there that I'm not thinking of right now. (What if state is linked? But maybe we can guarantee that it's never linked, because the sampler can control it.) Happy to discuss details more if unflatten by itself doesn't seem to cut it.
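A minimal sketch of that suggestion (assuming DynamicPPL.unflatten accepts a varinfo plus a flat value vector, and that `state` is unlinked; illustrative only):

```julia
# `state` holds exactly the variables this component samples, so the fresh
# draw can replace the haskey/loop logic below wholesale. `[updated;]` turns
# a scalar draw into the 1-element vector layout unflatten expects.
new_state = DynamicPPL.unflatten(state, [updated;])
```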

    # The symbol S tells us which variable to update
    vn = VarName{S}()

    # Check if the variable needs to be a vector
    new_vi = if haskey(state, vn)
        # Update the existing variable
        DynamicPPL.setindex!!(state, updated, vn)
    else
        # Try to find the variable with indices
        # This handles cases where the variable might have indices
        local updated_vi = state

Review comment (Copilot AI, Aug 7, 2025): The local keyword is unnecessary here since updated_vi is already in a local scope. This adds visual clutter without functional benefit.

Suggested change:
-        local updated_vi = state
+        updated_vi = state

        found = false
        for key in keys(state)
            if DynamicPPL.getsym(key) == S
                updated_vi = DynamicPPL.setindex!!(state, updated, key)
                found = true
                break
            end
        end
        if !found
            error("Could not find variable $S in VarInfo")

Review comment (Copilot AI, Aug 7, 2025): The error message could be more helpful by suggesting what variables are available or providing debugging information about the VarInfo contents.

Suggested change:
-            error("Could not find variable $S in VarInfo")
+            error("Could not find variable $S in VarInfo. Available variables: $(join([string(DynamicPPL.getsym(k)) for k in keys(state)], ", ")).")

        end
        updated_vi
    end

    # Update log joint probability
    new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext()))

Review comment (Member): I think you shouldn't need this, because the log joint is going to be recomputed anyway by the Gibbs sampler once it's looped over all component samplers. Saves one model evaluation.


    return nothing, new_vi
end

"""
setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo)

Update the variable info with new parameters for GibbsConditional.
"""
function setparams_varinfo!!(
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:GibbsConditional},
state,
params::DynamicPPL.AbstractVarInfo,
)
# For GibbsConditional, we just return the params as-is since
# the state is nothing and we don't need to update anything
return params
end

"""
gibbs_initialstep_recursive(
rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state
)

Initialize the GibbsConditional sampler.
"""
function gibbs_initialstep_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional},
target_varnames::AbstractVector{<:VarName},
global_vi::DynamicPPL.AbstractVarInfo,
prev_state,
)
# GibbsConditional doesn't need any special initialization
# Just perform one sampling step
return gibbs_step_recursive(
rng, model, sampler_wrapped, target_varnames, global_vi, nothing
)
end

"""
gibbs_step_recursive(
rng, model, sampler::GibbsConditional, target_varnames, global_vi, state
)

Perform a single step of GibbsConditional sampling.
"""
function gibbs_step_recursive(

Review comment (Member): I would hope you wouldn't need to overload this or gibbs_initialstep_recursive. Also, the below implementation seems to be just a repeat of step.

    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}},
    target_varnames::AbstractVector{<:VarName},
    global_vi::DynamicPPL.AbstractVarInfo,
    state,
) where {S}
    sampler = sampler_wrapped.alg

    # Extract conditioned values as a NamedTuple
    # Include both random variables and observed data
    condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
    condvals_obs = NamedTuple{keys(model.args)}(model.args)
    condvals = merge(condvals_vars, condvals_obs)

    # Get the conditional distribution
    conddist = sampler.conditional(condvals)

    # Sample from the conditional distribution
    updated = rand(rng, conddist)

    # Update the variable in global_vi
    # We need to get the actual VarName for this variable
    # The symbol S tells us which variable to update
    vn = VarName{S}()

    # Check if the variable needs to be a vector
    if haskey(global_vi, vn)
        # Update the existing variable
        global_vi = DynamicPPL.setindex!!(global_vi, updated, vn)
    else
        # Try to find the variable with indices
        # This handles cases where the variable might have indices
        for key in keys(global_vi)
            if DynamicPPL.getsym(key) == S
                global_vi = DynamicPPL.setindex!!(global_vi, updated, key)
                break
            end
        end
    end

    # Update log joint probability
    global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext()))

    return nothing, global_vi
end
114 changes: 114 additions & 0 deletions test/mcmc/gibbs.jl

@@ -882,6 +882,120 @@ end
        sampler = Gibbs(:w => HMC(0.05, 10))
        @test (sample(model, sampler, 10); true)
    end

    @testset "GibbsConditional" begin
        # Test with the inverse gamma example from the issue
        @model function inverse_gdemo(x)
            λ ~ Gamma(2, 3)
            m ~ Normal(0, sqrt(1 / λ))
            for i in 1:length(x)
                x[i] ~ Normal(m, sqrt(1 / λ))
            end
        end

        # Define analytical conditionals
        function cond_λ(c::NamedTuple)
            a = 2.0
            b = 3.0
            m = c.m
            x = c.x
            n = length(x)
            a_new = a + (n + 1) / 2
            b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
            return Gamma(a_new, 1 / b_new)
        end

        function cond_m(c::NamedTuple)
            λ = c.λ
            x = c.x
            n = length(x)
            m_mean = sum(x) / (n + 1)
            m_var = 1 / (λ * (n + 1))
            return Normal(m_mean, sqrt(m_var))
        end

        # Test basic functionality
        @testset "basic sampling" begin
            Random.seed!(42)
            x_obs = [1.0, 2.0, 3.0, 2.5, 1.5]
            model = inverse_gdemo(x_obs)

            # Test that GibbsConditional works
            sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m))
            chain = sample(model, sampler, 1000)

            # Check that we got the expected variables
            @test :λ in names(chain)
            @test :m in names(chain)

            # Check that the values are reasonable
            λ_samples = vec(chain[:λ])
            m_samples = vec(chain[:m])

            # Given the observed data, we expect certain behavior
            @test mean(λ_samples) > 0 # λ should be positive
            @test minimum(λ_samples) > 0
            @test std(m_samples) < 2.0 # m should be relatively well-constrained
        end

        # Test mixing with other samplers
        @testset "mixed samplers" begin
            Random.seed!(42)
            x_obs = [1.0, 2.0, 3.0]
            model = inverse_gdemo(x_obs)

            # Mix GibbsConditional with standard samplers
            sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH())
            chain = sample(model, sampler, 500)

            @test :λ in names(chain)
            @test :m in names(chain)
            @test size(chain, 1) == 500
        end

        # Test with a simpler model
        @testset "simple normal model" begin
            @model function simple_normal(x)
                μ ~ Normal(0, 10)
                σ ~ truncated(Normal(1, 1); lower=0.01)
                for i in 1:length(x)
                    x[i] ~ Normal(μ, σ)
                end
            end

            # Conditional for μ given σ and x
            function cond_μ(c::NamedTuple)
                σ = c.σ
                x = c.x
                n = length(x)
                # Prior: μ ~ Normal(0, 10)
                # Likelihood: x[i] ~ Normal(μ, σ)
                # Posterior: μ ~ Normal(μ_post, σ_post)
                prior_var = 100.0 # 10^2
                likelihood_var = σ^2 / n
                post_var = 1 / (1 / prior_var + n / σ^2)
                post_mean = post_var * (0 / prior_var + sum(x) / σ^2)
                return Normal(post_mean, sqrt(post_var))
            end

            Random.seed!(42)
            x_obs = randn(10) .+ 2.0 # Data centered around 2
            model = simple_normal(x_obs)

            sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH())

            chain = sample(model, sampler, 1000)

            μ_samples = vec(chain[:μ])
            @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean
        end

        # Test that GibbsConditional is marked as a valid component
        @testset "isgibbscomponent" begin
            gc = GibbsConditional(:x, c -> Normal(0, 1))
            @test Turing.Inference.isgibbscomponent(gc)
        end
    end
end

end