
Gibbs sampler #2647

Open · wants to merge 7 commits into base: main
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
@@ -67,6 +67,7 @@ export InferenceAlgorithm,
ESS,
Emcee,
Gibbs, # classic sampling
GibbsConditional, # conditional sampling
HMC,
SGLD,
PolynomialStepsize,
@@ -392,6 +393,7 @@ include("mh.jl")
include("is.jl")
include("particle_mcmc.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("sghmc.jl")
include("emcee.jl")
include("prior.jl")
173 changes: 173 additions & 0 deletions src/mcmc/gibbs_conditional.jl
@@ -0,0 +1,173 @@
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl

"""
GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

The `conditional` function should take a `NamedTuple` of conditioned variables and return
a `Distribution` from which to sample the variable `sym`.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
λ ~ Gamma(2, 3)
Review comment (Member): Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be Gamma(2, inv(3))?
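
A hedged note on the parameterization question above: `Distributions.Gamma(α, θ)` is shape/scale (mean `αθ`), so `Gamma(2, 3)` and `Gamma(2, inv(3))` are different priors. The sketch below only illustrates the difference; note that `cond_λ` further down returns `Gamma(a_new, 1 / b_new)`, i.e. it treats `b = 3` as a rate, which is the update a `Gamma(2, inv(3))` prior would produce.

```julia
using Distributions

mean(Gamma(2, 3))       # 6.0     (shape 2, scale 3, rate 1/3)
mean(Gamma(2, inv(3)))  # ≈ 0.667 (shape 2, scale 1/3, rate 3)

# cond_λ adds to b and uses scale 1/b_new, so b plays the role of a rate;
# conjugacy then requires the prior to have rate 3, i.e. Gamma(2, inv(3)).
```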

m ~ Normal(0, sqrt(1 / λ))
for i in 1:length(x)
x[i] ~ Normal(m, sqrt(1 / λ))
end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
a = 2.0
b = 3.0
m = c.m
x = c.x
n = length(x)
a_new = a + (n + 1) / 2
b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
Review comment (Member): Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the m in the variance term rather be the mean of x?
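
For reference, a worked derivation of the full conditional that `cond_λ` computes, under the assumptions λ ~ Gamma(a, rate b), m | λ ~ Normal(0, 1/√λ), and xᵢ | m, λ ~ Normal(m, 1/√λ); the sum-of-squares identity at the end shows why a version written in terms of mean(x), as the question suggests, agrees with the (xᵢ - m)² form used here:

```latex
p(\lambda \mid m, x)
  \propto \lambda^{a-1} e^{-b\lambda}
    \cdot \lambda^{1/2} e^{-\lambda m^{2}/2}
    \cdot \lambda^{n/2} e^{-\frac{\lambda}{2}\sum_{i=1}^{n}(x_i - m)^{2}}
  = \mathrm{Gamma}\!\left(a + \tfrac{n+1}{2},\ \text{rate}\ b + \tfrac{m^{2}}{2}
    + \tfrac{1}{2}\sum_{i=1}^{n}(x_i - m)^{2}\right)

% Identity relating the two forms:
\sum_{i=1}^{n}(x_i - m)^{2} = \sum_{i=1}^{n}(x_i - \bar{x})^{2} + n(\bar{x} - m)^{2}
```

This matches `a_new` and `b_new` in the code, with `b` read as a rate.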

return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
λ = c.λ
x = c.x
n = length(x)
m_mean = sum(x) / (n + 1)
m_var = 1 / (λ * (n + 1))
return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
:λ => GibbsConditional(:λ, cond_λ),
:m => GibbsConditional(:m, cond_m)
), 1000)
```
"""
struct GibbsConditional{C} <: InferenceAlgorithm
conditional::C

function GibbsConditional(sym::Symbol, conditional::C) where {C}
return new{C}(conditional)
end
end

# Mark GibbsConditional as a valid Gibbs component
isgibbscomponent(::GibbsConditional) = true

# Required methods for Gibbs constructor
Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable

"""
find_global_varinfo(context, fallback_vi)

Traverse the context stack to find global variable information from
GibbsContext, ConditionContext, FixedContext, etc.
"""
function find_global_varinfo(context, fallback_vi)
# Start with the given context and traverse down
current_context = context

while current_context !== nothing
if current_context isa GibbsContext
# Found GibbsContext, return its global varinfo
return get_global_varinfo(current_context)
elseif hasproperty(current_context, :childcontext) &&
isdefined(DynamicPPL, :childcontext)
# Move to child context if it exists
current_context = DynamicPPL.childcontext(current_context)
else
# No more child contexts
break
end
end

# If no GibbsContext found, use the fallback
return fallback_vi
end
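
To make the traversal pattern concrete, a self-contained toy sketch; the `Toy*` types below are hypothetical stand-ins, not DynamicPPL types, and only the loop structure mirrors `find_global_varinfo` above:

```julia
# Hypothetical stand-ins for a nested context stack.
struct ToyGibbsContext
    global_varinfo::Dict{Symbol,Float64}
end

struct ToyConditionContext{C}
    childcontext::C
end

toy_get_global_varinfo(ctx::ToyGibbsContext) = ctx.global_varinfo

function toy_find_global_varinfo(context, fallback)
    current = context
    while current !== nothing
        if current isa ToyGibbsContext
            return toy_get_global_varinfo(current)  # found the Gibbs layer
        elseif hasproperty(current, :childcontext)
            current = current.childcontext          # walk inward
        else
            break                                   # ran out of children
        end
    end
    return fallback
end

# A condition layer wrapping a Gibbs layer resolves to the inner varinfo:
stack = ToyConditionContext(ToyGibbsContext(Dict(:m => 0.5)))
toy_find_global_varinfo(stack, Dict{Symbol,Float64}())  # Dict(:m => 0.5)
```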

"""
DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Initialize the GibbsConditional sampler.
"""
function DynamicPPL.initialstep(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:GibbsConditional},
vi::DynamicPPL.AbstractVarInfo;
kwargs...,
)
# GibbsConditional doesn't need any special initialization
# Just return the initial state
return nothing, vi
end

"""
AbstractMCMC.step(rng, model, sampler::GibbsConditional, state)

Perform a step of GibbsConditional sampling.
"""
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:GibbsConditional{S}},
state::DynamicPPL.AbstractVarInfo;
kwargs...,
) where {S}
alg = sampler.alg

# For GibbsConditional within Gibbs, we need to get all variable values
# Traverse the context stack to find all conditioned/fixed/Gibbs variables
global_vi = if hasproperty(model, :context)
find_global_varinfo(model.context, state)
else
state
end

# Extract conditioned values as a NamedTuple
# Include both random variables and observed data
condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
condvals_obs = NamedTuple{keys(model.args)}(model.args)
condvals = merge(condvals_vars, condvals_obs)

# Get the conditional distribution
conddist = alg.conditional(condvals)

# Sample from the conditional distribution
updated = rand(rng, conddist)

# Update the variable in state
# The Gibbs sampler ensures that state only contains one variable
# Get the variable name from the keys
varname = first(keys(state))
new_vi = DynamicPPL.setindex!!(state, updated, varname)

return nothing, new_vi
end
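
For intuition about the `condvals` NamedTuple assembled in this step, a hedged illustration using the `inverse_gdemo` docstring example while λ is being updated; the field values are invented:

```julia
# Illustrative only: latent values and observations merged into one NamedTuple.
condvals = (λ = 0.8, m = 0.4, x = [1.0, 2.0, 3.0])
conddist = cond_λ(condvals)  # cond_λ (from the docstring) reads only c.m and c.x
λ_new = rand(conddist)       # the Gibbs update for λ
```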

"""
setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo)

Update the variable info with new parameters for GibbsConditional.
"""
function setparams_varinfo!!(
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:GibbsConditional},
state,
params::DynamicPPL.AbstractVarInfo,
)
# For GibbsConditional, we just return the params as-is since
# the state is nothing and we don't need to update anything
return params
end
114 changes: 114 additions & 0 deletions test/mcmc/gibbs.jl
@@ -882,6 +882,120 @@ end
sampler = Gibbs(:w => HMC(0.05, 10))
@test (sample(model, sampler, 10); true)
end

@testset "GibbsConditional" begin
# Test with the inverse gamma example from the issue
@model function inverse_gdemo(x)
λ ~ Gamma(2, 3)
m ~ Normal(0, sqrt(1 / λ))
for i in 1:length(x)
x[i] ~ Normal(m, sqrt(1 / λ))
end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
a = 2.0
b = 3.0
m = c.m
x = c.x
n = length(x)
a_new = a + (n + 1) / 2
b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
λ = c.λ
x = c.x
n = length(x)
m_mean = sum(x) / (n + 1)
m_var = 1 / (λ * (n + 1))
return Normal(m_mean, sqrt(m_var))
end

# Test basic functionality
@testset "basic sampling" begin
Random.seed!(42)
x_obs = [1.0, 2.0, 3.0, 2.5, 1.5]
model = inverse_gdemo(x_obs)

# Test that GibbsConditional works
sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m))
chain = sample(model, sampler, 1000)

# Check that we got the expected variables
@test :λ in names(chain)
@test :m in names(chain)

# Check that the values are reasonable
λ_samples = vec(chain[:λ])
m_samples = vec(chain[:m])

# Basic sanity checks on the posterior samples
@test mean(λ_samples) > 0 # λ should be positive
@test minimum(λ_samples) > 0
@test std(m_samples) < 2.0 # m should be relatively well-constrained
end

# Test mixing with other samplers
@testset "mixed samplers" begin
Random.seed!(42)
x_obs = [1.0, 2.0, 3.0]
model = inverse_gdemo(x_obs)

# Mix GibbsConditional with standard samplers
sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH())
chain = sample(model, sampler, 500)

@test :λ in names(chain)
@test :m in names(chain)
@test size(chain, 1) == 500
end

# Test with a simpler model
@testset "simple normal model" begin
@model function simple_normal(x)
μ ~ Normal(0, 10)
σ ~ truncated(Normal(1, 1); lower=0.01)
for i in 1:length(x)
x[i] ~ Normal(μ, σ)
end
end

# Conditional for μ given σ and x
function cond_μ(c::NamedTuple)
σ = c.σ
x = c.x
n = length(x)
# Prior: μ ~ Normal(0, 10)
# Likelihood: x[i] ~ Normal(μ, σ)
# Posterior: μ ~ Normal(μ_post, σ_post)
prior_var = 100.0 # 10^2
likelihood_var = σ^2 / n
post_var = 1 / (1 / prior_var + n / σ^2)
post_mean = post_var * (0 / prior_var + sum(x) / σ^2)
return Normal(post_mean, sqrt(post_var))
end
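# Worked form of the update above (standard conjugate normal-mean result,
# restated for reference; prior mean μ₀ = 0, prior variance τ² = 100):
#   posterior precision: 1/τ² + n/σ²   (the inverse of post_var)
#   posterior mean:      post_var * (μ₀/τ² + Σᵢ xᵢ/σ²)
# which is exactly what post_var and post_mean compute.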

Random.seed!(42)
x_obs = randn(10) .+ 2.0 # Data centered around 2
model = simple_normal(x_obs)

sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH())

chain = sample(model, sampler, 1000)

μ_samples = vec(chain[:μ])
@test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean
end

# Test that GibbsConditional is marked as a valid component
@testset "isgibbscomponent" begin
gc = GibbsConditional(:x, c -> Normal(0, 1))
@test Turing.Inference.isgibbscomponent(gc)
end
end
end

end
78 changes: 78 additions & 0 deletions test_gibbs_conditional.jl
@@ -0,0 +1,78 @@
using Turing
using Turing.Inference: GibbsConditional
using Distributions
using Random
using Statistics

# Test with the inverse gamma example from the issue
@model function inverse_gdemo(x)
λ ~ Gamma(2, 3)
m ~ Normal(0, sqrt(1 / λ))
for i in 1:length(x)
x[i] ~ Normal(m, sqrt(1 / λ))
end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
a = 2.0
b = 3.0
m = c.m
x = c.x
n = length(x)
a_new = a + (n + 1) / 2
b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
λ = c.λ
x = c.x
n = length(x)
m_mean = sum(x) / (n + 1)
m_var = 1 / (λ * (n + 1))
return Normal(m_mean, sqrt(m_var))
end

# Generate some observed data
Random.seed!(42)
x_obs = [1.0, 2.0, 3.0, 2.5, 1.5]

# Create the model
model = inverse_gdemo(x_obs)

# Sample using GibbsConditional
println("Testing GibbsConditional sampler...")
Review comment (Member): You probably want various @test calls rather than a lot of prints.

sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
Review comment (Copilot AI, Aug 7, 2025): [nitpick] The variable name is redundantly specified in both the `Gibbs` pair key and the `GibbsConditional` constructor. Consider whether this duplication is necessary or if the API could be simplified.

Suggested change:
-sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
+sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m))
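
If the API were simplified along the suggested lines, one hypothetical route (a sketch, not part of this PR): since the inner constructor already discards `sym`, a one-argument convenience method could coexist with the current form, making the `Gibbs` pair key the single source of truth for the variable name.

```julia
# Hypothetical convenience constructor; the placeholder symbol is arbitrary
# because the stored value is never used.
GibbsConditional(conditional) = GibbsConditional(:_, conditional)
```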


# Run a short chain to test
chain = sample(model, sampler, 100)

println("Sampling completed successfully!")
println("\nChain summary:")
println(chain)

# Extract samples
λ_samples = vec(chain[:λ])
m_samples = vec(chain[:m])

println("\nλ statistics:")
println(" Mean: ", mean(λ_samples))
println(" Std: ", std(λ_samples))
println(" Min: ", minimum(λ_samples))
println(" Max: ", maximum(λ_samples))

println("\nm statistics:")
println(" Mean: ", mean(m_samples))
println(" Std: ", std(m_samples))
println(" Min: ", minimum(m_samples))
println(" Max: ", maximum(m_samples))

# Test mixing with other samplers
println("\n\nTesting mixed samplers...")
sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH())

chain2 = sample(model, sampler2, 100)
println("Mixed sampling completed successfully!")
println("\nMixed chain summary:")
println(chain2)