diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl
index 0951026aa..35cdc46b5 100644
--- a/src/mcmc/Inference.jl
+++ b/src/mcmc/Inference.jl
@@ -67,6 +67,7 @@ export InferenceAlgorithm,
     ESS,
     Emcee,
     Gibbs, # classic sampling
+    GibbsConditional, # conditional sampling
     HMC,
     SGLD,
     PolynomialStepsize,
@@ -392,6 +393,7 @@ include("mh.jl")
 include("is.jl")
 include("particle_mcmc.jl")
 include("gibbs.jl")
+include("gibbs_conditional.jl")
 include("sghmc.jl")
 include("emcee.jl")
 include("prior.jl")
diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl
new file mode 100644
index 000000000..7415c5f3f
--- /dev/null
+++ b/src/mcmc/gibbs_conditional.jl
@@ -0,0 +1,173 @@
+using DynamicPPL: VarName
+using Random: Random
+import AbstractMCMC
+
+# These functions provide specialized methods for GibbsConditional that extend
+# the generic implementations in gibbs.jl.
+
+"""
+    GibbsConditional(sym::Symbol, conditional)
+
+A Gibbs sampler component that samples a variable according to a user-provided
+analytical conditional distribution.
+
+The `conditional` function should take a `NamedTuple` of conditioned variables and return
+a `Distribution` from which to sample the variable `sym`.
+
+# Examples
+
+```julia
+# Define a model
+@model function inverse_gdemo(x)
+    λ ~ Gamma(2, inv(3))  # shape 2, rate 3 (Distributions.jl uses shape/scale)
+    m ~ Normal(0, sqrt(1 / λ))
+    for i in 1:length(x)
+        x[i] ~ Normal(m, sqrt(1 / λ))
+    end
+end
+
+# Define analytical conditionals
+function cond_λ(c::NamedTuple)
+    a = 2.0  # prior shape
+    b = 3.0  # prior rate
+    m = c.m
+    x = c.x
+    n = length(x)
+    a_new = a + (n + 1) / 2
+    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
+    return Gamma(a_new, 1 / b_new)
+end
+
+function cond_m(c::NamedTuple)
+    λ = c.λ
+    x = c.x
+    n = length(x)
+    m_mean = sum(x) / (n + 1)
+    m_var = 1 / (λ * (n + 1))
+    return Normal(m_mean, sqrt(m_var))
+end
+
+# Sample using GibbsConditional
+model = inverse_gdemo([1.0, 2.0, 3.0])
+chain = sample(model, Gibbs(
+    :λ => GibbsConditional(:λ, cond_λ),
+    :m => GibbsConditional(:m, cond_m)
+), 1000)
+```
+"""
+struct GibbsConditional{C} <: InferenceAlgorithm
+    conditional::C
+
+    function GibbsConditional(sym::Symbol, conditional::C) where {C}
+        # The symbol is not stored; within Gibbs, the target variable is
+        # determined by the pair key (e.g. `:λ => GibbsConditional(:λ, ...)`).
+        return new{C}(conditional)
+    end
+end
+
+# Mark GibbsConditional as a valid Gibbs component.
+isgibbscomponent(::GibbsConditional) = true
+
+# Required by the Gibbs constructor: each GibbsConditional handles one variable.
+Base.length(::GibbsConditional) = 1
+
+"""
+    find_global_varinfo(context, fallback_vi)
+
+Traverse the context stack to find global variable information from
+`GibbsContext`, `ConditionContext`, `FixedContext`, etc.
+"""
+function find_global_varinfo(context, fallback_vi)
+    # Start with the given context and walk towards the leaf.
+    current_context = context
+    while current_context !== nothing
+        if current_context isa GibbsContext
+            # Found a GibbsContext; return its global varinfo.
+            return get_global_varinfo(current_context)
+        elseif DynamicPPL.NodeTrait(current_context) isa DynamicPPL.IsParent
+            # Descend into the child context.
+            current_context = DynamicPPL.childcontext(current_context)
+        else
+            # Reached a leaf context without finding a GibbsContext.
+            break
+        end
+    end
+    # No GibbsContext found; use the fallback.
+    return fallback_vi
+end
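+
+# A minimal illustration of the fallback behaviour above (a sketch, assuming
+# only DynamicPPL's public context API; `model` is any instantiated model):
+#
+#     vi = DynamicPPL.VarInfo(model)
+#     # `DefaultContext` is a leaf and is not a `GibbsContext`, so the
+#     # traversal terminates immediately and returns the fallback unchanged:
+#     find_global_varinfo(DynamicPPL.DefaultContext(), vi) === vi  # true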
+""" +function DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + vi::DynamicPPL.AbstractVarInfo; + kwargs..., +) + # GibbsConditional doesn't need any special initialization + # Just return the initial state + return nothing, vi +end + +""" + AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) + +Perform a step of GibbsConditional sampling. +""" +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, + state::DynamicPPL.AbstractVarInfo; + kwargs..., +) where {S} + alg = sampler.alg + + # For GibbsConditional within Gibbs, we need to get all variable values + # Traverse the context stack to find all conditioned/fixed/Gibbs variables + global_vi = if hasproperty(model, :context) + find_global_varinfo(model.context, state) + else + state + end + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = alg.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in state + # The Gibbs sampler ensures that state only contains one variable + # Get the variable name from the keys + varname = first(keys(state)) + new_vi = DynamicPPL.setindex!!(state, updated, varname) + + return nothing, new_vi +end + +""" + setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) + +Update the variable info with new parameters for GibbsConditional. 
+""" +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + state, + params::DynamicPPL.AbstractVarInfo, +) + # For GibbsConditional, we just return the params as-is since + # the state is nothing and we don't need to update anything + return params +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc..a7884bb7e 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -882,6 +882,120 @@ end sampler = Gibbs(:w => HMC(0.05, 10)) @test (sample(model, sampler, 10); true) end + + @testset "GibbsConditional" begin + # Test with the inverse gamma example from the issue + @model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end + end + + # Define analytical conditionals + function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end + + function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end + + # Test basic functionality + @testset "basic sampling" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + model = inverse_gdemo(x_obs) + + # Test that GibbsConditional works + sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)) + chain = sample(model, sampler, 1000) + + # Check that we got the expected variables + @test :λ in names(chain) + @test :m in names(chain) + + # Check that the values are reasonable + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) + + # Given the observed data, we expect certain behavior + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 + @test std(m_samples) < 2.0 # m should be relatively well-constrained + end + + # Test mixing with other samplers + @testset "mixed samplers" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0] + model = inverse_gdemo(x_obs) + + # Mix GibbsConditional with standard samplers + sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH()) + chain = sample(model, sampler, 500) + + @test :λ in names(chain) + @test :m in names(chain) + @test size(chain, 1) == 500 + end + + # Test with a simpler model + @testset "simple normal model" begin + @model function simple_normal(x) + μ ~ Normal(0, 10) + σ ~ truncated(Normal(1, 1); lower=0.01) + for i in 1:length(x) + x[i] ~ Normal(μ, σ) + end + end + + # Conditional for μ given σ and x + function cond_μ(c::NamedTuple) + σ = c.σ + x = c.x + n = length(x) + # Prior: μ ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(μ, σ) + # Posterior: μ ~ Normal(μ_post, σ_post) + prior_var = 100.0 # 10^2 + likelihood_var = σ^2 / n + post_var = 1 / (1 / prior_var + n / σ^2) + post_mean = post_var * (0 / prior_var + sum(x) / σ^2) + return Normal(post_mean, sqrt(post_var)) + end + + Random.seed!(42) + x_obs = randn(10) .+ 2.0 # Data centered around 2 + model = simple_normal(x_obs) + + sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH()) + + chain = sample(model, sampler, 1000) + + μ_samples = vec(chain[:μ]) + @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean + end + + # Test that GibbsConditional is marked as a valid component + @testset "isgibbscomponent" begin + gc = GibbsConditional(:x, c -> Normal(0, 1)) + @test Turing.Inference.isgibbscomponent(gc) + end + end end end diff --git a/test_gibbs_conditional.jl 
+        end
+
+        # Test mixing with other samplers
+        @testset "mixed samplers" begin
+            Random.seed!(42)
+            x_obs = [1.0, 2.0, 3.0]
+            model = inverse_gdemo(x_obs)
+
+            # Mix GibbsConditional with standard samplers
+            sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH())
+            chain = sample(model, sampler, 500)
+
+            @test :λ in names(chain)
+            @test :m in names(chain)
+            @test size(chain, 1) == 500
+        end
+
+        # Test with a simpler model
+        @testset "simple normal model" begin
+            @model function simple_normal(x)
+                μ ~ Normal(0, 10)
+                σ ~ truncated(Normal(1, 1); lower=0.01)
+                for i in 1:length(x)
+                    x[i] ~ Normal(μ, σ)
+                end
+            end
+
+            # Conditional for μ given σ and x:
+            # prior μ ~ Normal(0, 10), likelihood x[i] ~ Normal(μ, σ), so by
+            # conjugacy the posterior is Normal(post_mean, sqrt(post_var)).
+            function cond_μ(c::NamedTuple)
+                σ = c.σ
+                x = c.x
+                n = length(x)
+                prior_var = 100.0 # 10^2
+                post_var = 1 / (1 / prior_var + n / σ^2)
+                post_mean = post_var * (sum(x) / σ^2) # prior mean is zero
+                return Normal(post_mean, sqrt(post_var))
+            end
+
+            Random.seed!(42)
+            x_obs = randn(10) .+ 2.0 # Data centered around 2
+            model = simple_normal(x_obs)
+
+            sampler = Gibbs(:μ => GibbsConditional(:μ, cond_μ), :σ => MH())
+            chain = sample(model, sampler, 1000)
+
+            μ_samples = vec(chain[:μ])
+            @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to the true mean
+        end
+
+        # Test that GibbsConditional is marked as a valid component
+        @testset "isgibbscomponent" begin
+            gc = GibbsConditional(:x, c -> Normal(0, 1))
+            @test Turing.Inference.isgibbscomponent(gc)
+        end
+    end
 end
 end
diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl
new file mode 100644
index 000000000..d6466e537
--- /dev/null
+++ b/test_gibbs_conditional.jl
@@ -0,0 +1,78 @@
+using Turing
+using Turing.Inference: GibbsConditional
+using Distributions
+using Random
+using Statistics
+
+# Test with the inverse-gamma example from the issue
+@model function inverse_gdemo(x)
+    λ ~ Gamma(2, inv(3))  # shape 2, rate 3
+    m ~ Normal(0, sqrt(1 / λ))
+    for i in 1:length(x)
+        x[i] ~ Normal(m, sqrt(1 / λ))
+    end
+end
+
+# Define analytical conditionals
+function cond_λ(c::NamedTuple)
+    a = 2.0  # prior shape
+    b = 3.0  # prior rate
+    m = c.m
+    x = c.x
+    n = length(x)
+    a_new = a + (n + 1) / 2
+    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
+    return Gamma(a_new, 1 / b_new)
+end
+
+function cond_m(c::NamedTuple)
+    λ = c.λ
+    x = c.x
+    n = length(x)
+    m_mean = sum(x) / (n + 1)
+    m_var = 1 / (λ * (n + 1))
+    return Normal(m_mean, sqrt(m_var))
+end
+
+# Generate some observed data
+Random.seed!(42)
+x_obs = [1.0, 2.0, 3.0, 2.5, 1.5]
+
+# Create the model
+model = inverse_gdemo(x_obs)
+
+# Sample using GibbsConditional
+println("Testing GibbsConditional sampler...")
+sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
+
+# Run a short chain to test
+chain = sample(model, sampler, 100)
+
+println("Sampling completed successfully!")
+println("\nChain summary:")
+println(chain)
+
+# Extract samples
+λ_samples = vec(chain[:λ])
+m_samples = vec(chain[:m])
+
+println("\nλ statistics:")
+println("  Mean: ", mean(λ_samples))
+println("  Std:  ", std(λ_samples))
+println("  Min:  ", minimum(λ_samples))
+println("  Max:  ", maximum(λ_samples))
+
+println("\nm statistics:")
+println("  Mean: ", mean(m_samples))
+println("  Std:  ", std(m_samples))
+println("  Min:  ", minimum(m_samples))
+println("  Max:  ", maximum(m_samples))
+
+# Test mixing with other samplers
+println("\n\nTesting mixed samplers...")
+sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH())
+
+chain2 = sample(model, sampler2, 100)
+println("Mixed sampling completed successfully!")
+println("\nMixed chain summary:")
+println(chain2)
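+
+# Optional sanity check (a sketch; `m_exact` is introduced here only for
+# illustration): for this conjugate model the posterior mean of m is
+# sum(x_obs) / (n + 1) regardless of λ, so the Gibbs estimate should land
+# near it for longer chains.
+m_exact = sum(x_obs) / (length(x_obs) + 1)
+println("\nAnalytical E[m | x] = ", m_exact)
+println("Gibbs estimate of m = ", mean(m_samples))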