From c0158ead74a35b620781ae691f73e1c98550e5c6 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 7 Aug 2025 09:08:25 +0100 Subject: [PATCH 1/5] Add GibbsConditional sampler and corresponding tests --- src/mcmc/Inference.jl | 2 + src/mcmc/gibbs_conditional.jl | 245 ++++++++++++++++++++++++++++++++++ test/mcmc/gibbs.jl | 114 ++++++++++++++++ test_gibbs_conditional.jl | 78 +++++++++++ 4 files changed, 439 insertions(+) create mode 100644 src/mcmc/gibbs_conditional.jl create mode 100644 test_gibbs_conditional.jl diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 0951026aa..35cdc46b5 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -67,6 +67,7 @@ export InferenceAlgorithm, ESS, Emcee, Gibbs, # classic sampling + GibbsConditional, # conditional sampling HMC, SGLD, PolynomialStepsize, @@ -392,6 +393,7 @@ include("mh.jl") include("is.jl") include("particle_mcmc.jl") include("gibbs.jl") +include("gibbs_conditional.jl") include("sghmc.jl") include("emcee.jl") include("prior.jl") diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl new file mode 100644 index 000000000..e01e17aff --- /dev/null +++ b/src/mcmc/gibbs_conditional.jl @@ -0,0 +1,245 @@ +using DynamicPPL: VarName +using Random: Random +import AbstractMCMC + +# These functions are defined in gibbs.jl which is loaded before this file + +""" + GibbsConditional(sym::Symbol, conditional) + +A Gibbs sampler component that samples a variable according to a user-provided +analytical conditional distribution. + +The `conditional` function should take a `NamedTuple` of conditioned variables and return +a `Distribution` from which to sample the variable `sym`. + +# Examples + +```julia +# Define a model +@model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end +end + +# Define analytical conditionals +function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) +end + +function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) +end + +# Sample using GibbsConditional +model = inverse_gdemo([1.0, 2.0, 3.0]) +chain = sample(model, Gibbs( + :λ => GibbsConditional(:λ, cond_λ), + :m => GibbsConditional(:m, cond_m) +), 1000) +``` +""" +struct GibbsConditional{S,C} <: InferenceAlgorithm + conditional::C + + function GibbsConditional(sym::Symbol, conditional::C) where {C} + return new{sym,C}(conditional) + end +end + +# Mark GibbsConditional as a valid Gibbs component +isgibbscomponent(::GibbsConditional) = true + +""" + DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) + +Initialize the GibbsConditional sampler. +""" +function DynamicPPL.initialstep( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + vi::DynamicPPL.AbstractVarInfo; + kwargs..., +) + # GibbsConditional doesn't need any special initialization + # Just return the initial state + return nothing, vi +end + +""" + AbstractMCMC.step(rng, model, sampler::GibbsConditional, state) + +Perform a step of GibbsConditional sampling. +""" +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional{S}}, + state::DynamicPPL.AbstractVarInfo; + kwargs..., +) where {S} + alg = sampler.alg + + # For GibbsConditional within Gibbs, we need to get all variable values + # Check if we're in a Gibbs context + global_vi = if hasproperty(model, :context) && model.context isa GibbsContext + # We're in a Gibbs context, get the global varinfo + get_global_varinfo(model.context) + else + # We're not in a Gibbs context, use the current state + state + end + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = alg.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in state + # We need to get the actual VarName for this variable + # The symbol S tells us which variable to update + vn = VarName{S}() + + # Check if the variable needs to be a vector + new_vi = if haskey(state, vn) + # Update the existing variable + DynamicPPL.setindex!!(state, updated, vn) + else + # Try to find the variable with indices + # This handles cases where the variable might have indices + local updated_vi = state + found = false + for key in keys(state) + if DynamicPPL.getsym(key) == S + updated_vi = DynamicPPL.setindex!!(state, updated, key) + found = true + break + end + end + if !found + error("Could not find variable $S in VarInfo") + end + updated_vi + end + + # Update log joint probability + new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) + + return nothing, new_vi +end + +""" + setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo) + +Update the variable info with new parameters for GibbsConditional. +""" +function setparams_varinfo!!( + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler{<:GibbsConditional}, + state, + params::DynamicPPL.AbstractVarInfo, +) + # For GibbsConditional, we just return the params as-is since + # the state is nothing and we don't need to update anything + return params +end + +""" + gibbs_initialstep_recursive( + rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state + ) + +Initialize the GibbsConditional sampler. +""" +function gibbs_initialstep_recursive( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, + target_varnames::AbstractVector{<:VarName}, + global_vi::DynamicPPL.AbstractVarInfo, + prev_state, +) + # GibbsConditional doesn't need any special initialization + # Just perform one sampling step + return gibbs_step_recursive( + rng, model, sampler_wrapped, target_varnames, global_vi, nothing + ) +end + +""" + gibbs_step_recursive( + rng, model, sampler::GibbsConditional, target_varnames, global_vi, state + ) + +Perform a single step of GibbsConditional sampling. +""" +function gibbs_step_recursive( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, + target_varnames::AbstractVector{<:VarName}, + global_vi::DynamicPPL.AbstractVarInfo, + state, +) where {S} + sampler = sampler_wrapped.alg + + # Extract conditioned values as a NamedTuple + # Include both random variables and observed data + condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) + condvals_obs = NamedTuple{keys(model.args)}(model.args) + condvals = merge(condvals_vars, condvals_obs) + + # Get the conditional distribution + conddist = sampler.conditional(condvals) + + # Sample from the conditional distribution + updated = rand(rng, conddist) + + # Update the variable in global_vi + # We need to get the actual VarName for this variable + # The symbol S tells us which variable to update + vn = VarName{S}() + + # Check if the variable needs to be a vector + if haskey(global_vi, vn) + # Update the existing variable + global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) + else + # Try to find the variable with indices + # This handles cases where the variable might have indices + for key in keys(global_vi) + if DynamicPPL.getsym(key) == S + global_vi = DynamicPPL.setindex!!(global_vi, updated, key) + break + end + end + end + + # Update log joint probability + global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) + + return nothing, global_vi +end diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index f44a9fefc..a7884bb7e 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -882,6 +882,120 @@ end sampler = Gibbs(:w => HMC(0.05, 10)) @test (sample(model, sampler, 10); true) end + + @testset "GibbsConditional" begin + # Test with the inverse gamma example from the issue + @model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end + end + + # Define analytical conditionals + function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) + end + + function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) + end + + # Test basic functionality + @testset "basic sampling" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + model = inverse_gdemo(x_obs) + + # Test that GibbsConditional works + sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)) + chain = sample(model, sampler, 1000) + + # Check that we got the expected variables + @test :λ in names(chain) + @test :m in names(chain) + + # Check that the values are reasonable + λ_samples = vec(chain[:λ]) + m_samples = vec(chain[:m]) + + # Given the observed data, we expect certain behavior + @test mean(λ_samples) > 0 # λ should be positive + @test minimum(λ_samples) > 0 + @test std(m_samples) < 2.0 # m should be relatively well-constrained + end + + # Test mixing with other samplers + @testset "mixed samplers" begin + Random.seed!(42) + x_obs = [1.0, 2.0, 3.0] + model = inverse_gdemo(x_obs) + + # Mix GibbsConditional with standard samplers + sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH()) + chain = sample(model, sampler, 500) + + @test :λ in names(chain) + @test :m in names(chain) + @test size(chain, 1) == 500 + end + + # Test with a simpler model + @testset "simple normal model" begin + @model function simple_normal(x) + μ ~ Normal(0, 10) + σ ~ truncated(Normal(1, 1); lower=0.01) + for i in 1:length(x) + x[i] ~ Normal(μ, σ) + end + end + + # Conditional for μ given σ and x + function cond_μ(c::NamedTuple) + σ = c.σ + x = c.x + n = length(x) + # Prior: μ ~ Normal(0, 10) + # Likelihood: x[i] ~ Normal(μ, σ) + # Posterior: μ ~ Normal(μ_post, σ_post) + prior_var = 100.0 # 10^2 + likelihood_var = σ^2 / n + post_var = 1 / (1 / prior_var + n / σ^2) + post_mean = post_var * (0 / prior_var + sum(x) / σ^2) + return Normal(post_mean, sqrt(post_var)) + end + + Random.seed!(42) + x_obs = randn(10) .+ 2.0 # Data centered around 2 + model = simple_normal(x_obs) + + sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH()) + + chain = sample(model, sampler, 1000) + + μ_samples = vec(chain[:μ]) + @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean + end + + # Test that GibbsConditional is marked as a valid component + @testset "isgibbscomponent" begin + gc = GibbsConditional(:x, c -> Normal(0, 1)) + @test Turing.Inference.isgibbscomponent(gc) + end + end end end diff --git a/test_gibbs_conditional.jl b/test_gibbs_conditional.jl new file mode 100644 index 000000000..d6466e537 --- /dev/null +++ b/test_gibbs_conditional.jl @@ -0,0 +1,78 @@ +using Turing +using Turing.Inference: GibbsConditional +using Distributions +using Random +using Statistics + +# Test with the inverse gamma example from the issue +@model function inverse_gdemo(x) + λ ~ Gamma(2, 3) + m ~ Normal(0, sqrt(1 / λ)) + for i in 1:length(x) + x[i] ~ Normal(m, sqrt(1 / λ)) + end +end + +# Define analytical conditionals +function cond_λ(c::NamedTuple) + a = 2.0 + b = 3.0 + m = c.m + x = c.x + n = length(x) + a_new = a + (n + 1) / 2 + b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2 + return Gamma(a_new, 1 / b_new) +end + +function cond_m(c::NamedTuple) + λ = c.λ + x = c.x + n = length(x) + m_mean = sum(x) / (n + 1) + m_var = 1 / (λ * (n + 1)) + return Normal(m_mean, sqrt(m_var)) +end + +# Generate some observed data +Random.seed!(42) +x_obs = [1.0, 2.0, 3.0, 2.5, 1.5] + +# Create the model +model = inverse_gdemo(x_obs) + +# Sample using GibbsConditional +println("Testing GibbsConditional sampler...") +sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m)) + +# Run a short chain to test +chain = sample(model, sampler, 100) + +println("Sampling completed successfully!") +println("\nChain summary:") +println(chain) + +# Extract samples +λ_samples = vec(chain[:λ]) +m_samples = vec(chain[:m]) + +println("\nλ statistics:") +println(" Mean: ", mean(λ_samples)) +println(" Std: ", std(λ_samples)) +println(" Min: ", minimum(λ_samples)) +println(" Max: ", maximum(λ_samples)) + +println("\nm statistics:") +println(" Mean: ", mean(m_samples)) +println(" Std: ", std(m_samples)) +println(" Min: ", minimum(m_samples)) +println(" Max: ", maximum(m_samples)) + +# Test mixing with other samplers +println("\n\nTesting mixed samplers...") +sampler2 = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => MH()) + +chain2 = sample(model, sampler2, 100) +println("Mixed sampling completed successfully!") +println("\nMixed chain summary:") +println(chain2) From a972b5a2b54216b2dfb5b5ee2fa807229be081bb Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Thu, 7 Aug 2025 09:13:15 +0100 Subject: [PATCH 2/5] clarified comment --- src/mcmc/gibbs_conditional.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index e01e17aff..2bf7a7bb5 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -2,7 +2,7 @@ using DynamicPPL: VarName using Random: Random import AbstractMCMC -# These functions are defined in gibbs.jl which is loaded before this file +# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl """ GibbsConditional(sym::Symbol, conditional) From c3cc7739cbcae0675ed50bff85dcdd34ddd29c3a Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 11:11:32 +0100 Subject: [PATCH 3/5] add MHs suggestions --- src/mcmc/gibbs_conditional.jl | 148 +++++++++------------------------- 1 file changed, 37 insertions(+), 111 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index 2bf7a7bb5..c2eba05ba 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -54,17 +54,45 @@ chain = sample(model, Gibbs( ), 1000) ``` """ -struct GibbsConditional{S,C} <: InferenceAlgorithm +struct GibbsConditional{C} <: InferenceAlgorithm conditional::C function GibbsConditional(sym::Symbol, conditional::C) where {C} - return new{sym,C}(conditional) + return new{C}(conditional) end end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true +""" + find_global_varinfo(context, fallback_vi) + +Traverse the context stack to find global variable information from +GibbsContext, ConditionContext, FixedContext, etc. +""" +function find_global_varinfo(context, fallback_vi) + # Start with the given context and traverse down + current_context = context + + while current_context !== nothing + if current_context isa GibbsContext + # Found GibbsContext, return its global varinfo + return get_global_varinfo(current_context) + elseif hasproperty(current_context, :childcontext) && + isdefined(DynamicPPL, :childcontext) + # Move to child context if it exists + current_context = DynamicPPL.childcontext(current_context) + else + # No more child contexts + break + end + end + + # If no GibbsContext found, use the fallback + return fallback_vi +end + """ DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi) @@ -97,12 +125,10 @@ function AbstractMCMC.step( alg = sampler.alg # For GibbsConditional within Gibbs, we need to get all variable values - # Check if we're in a Gibbs context - global_vi = if hasproperty(model, :context) && model.context isa GibbsContext - # We're in a Gibbs context, get the global varinfo - get_global_varinfo(model.context) + # Traverse the context stack to find all conditioned/fixed/Gibbs variables + global_vi = if hasproperty(model, :context) + find_global_varinfo(model.context, state) else - # We're not in a Gibbs context, use the current state state end @@ -119,34 +145,10 @@ function AbstractMCMC.step( updated = rand(rng, conddist) # Update the variable in state - # We need to get the actual VarName for this variable - # The symbol S tells us which variable to update - vn = VarName{S}() - - # Check if the variable needs to be a vector - new_vi = if haskey(state, vn) - # Update the existing variable - DynamicPPL.setindex!!(state, updated, vn) - else - # Try to find the variable with indices - # This handles cases where the variable might have indices - local updated_vi = state - found = false - for key in keys(state) - if DynamicPPL.getsym(key) == S - updated_vi = DynamicPPL.setindex!!(state, updated, key) - found = true - break - end - end - if !found - error("Could not find variable $S in VarInfo") - end - updated_vi - end - - # Update log joint probability - new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext())) + # The Gibbs sampler ensures that state only contains one variable + # Get the variable name from the keys + varname = first(keys(state)) + new_vi = DynamicPPL.setindex!!(state, updated, varname) return nothing, new_vi end @@ -167,79 +169,3 @@ function setparams_varinfo!!( return params end -""" - gibbs_initialstep_recursive( - rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state - ) - -Initialize the GibbsConditional sampler. -""" -function gibbs_initialstep_recursive( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional}, - target_varnames::AbstractVector{<:VarName}, - global_vi::DynamicPPL.AbstractVarInfo, - prev_state, -) - # GibbsConditional doesn't need any special initialization - # Just perform one sampling step - return gibbs_step_recursive( - rng, model, sampler_wrapped, target_varnames, global_vi, nothing - ) -end - -""" - gibbs_step_recursive( - rng, model, sampler::GibbsConditional, target_varnames, global_vi, state - ) - -Perform a single step of GibbsConditional sampling. -""" -function gibbs_step_recursive( - rng::Random.AbstractRNG, - model::DynamicPPL.Model, - sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}}, - target_varnames::AbstractVector{<:VarName}, - global_vi::DynamicPPL.AbstractVarInfo, - state, -) where {S} - sampler = sampler_wrapped.alg - - # Extract conditioned values as a NamedTuple - # Include both random variables and observed data - condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple) - condvals_obs = NamedTuple{keys(model.args)}(model.args) - condvals = merge(condvals_vars, condvals_obs) - - # Get the conditional distribution - conddist = sampler.conditional(condvals) - - # Sample from the conditional distribution - updated = rand(rng, conddist) - - # Update the variable in global_vi - # We need to get the actual VarName for this variable - # The symbol S tells us which variable to update - vn = VarName{S}() - - # Check if the variable needs to be a vector - if haskey(global_vi, vn) - # Update the existing variable - global_vi = DynamicPPL.setindex!!(global_vi, updated, vn) - else - # Try to find the variable with indices - # This handles cases where the variable might have indices - for key in keys(global_vi) - if DynamicPPL.getsym(key) == S - global_vi = DynamicPPL.setindex!!(global_vi, updated, key) - break - end - end - end - - # Update log joint probability - global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext())) - - return nothing, global_vi -end From 714c1e82979e5daa6cb1d005c113c967e9d4647a Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 11:11:52 +0100 Subject: [PATCH 4/5] formatter --- src/mcmc/gibbs_conditional.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index c2eba05ba..fe04b048d 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -74,13 +74,13 @@ GibbsContext, ConditionContext, FixedContext, etc. function find_global_varinfo(context, fallback_vi) # Start with the given context and traverse down current_context = context - + while current_context !== nothing if current_context isa GibbsContext # Found GibbsContext, return its global varinfo return get_global_varinfo(current_context) - elseif hasproperty(current_context, :childcontext) && - isdefined(DynamicPPL, :childcontext) + elseif hasproperty(current_context, :childcontext) && + isdefined(DynamicPPL, :childcontext) # Move to child context if it exists current_context = DynamicPPL.childcontext(current_context) else @@ -88,7 +88,7 @@ function find_global_varinfo(context, fallback_vi) break end end - + # If no GibbsContext found, use the fallback return fallback_vi end @@ -168,4 +168,3 @@ function setparams_varinfo!!( # the state is nothing and we don't need to update anything return params end - From 94b723da263927edfef7c20d8e56e543d0d84fc3 Mon Sep 17 00:00:00 2001 From: AoifeHughes Date: Fri, 8 Aug 2025 14:37:27 +0100 Subject: [PATCH 5/5] fixed exporting thing --- src/mcmc/gibbs_conditional.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mcmc/gibbs_conditional.jl b/src/mcmc/gibbs_conditional.jl index fe04b048d..7415c5f3f 100644 --- a/src/mcmc/gibbs_conditional.jl +++ b/src/mcmc/gibbs_conditional.jl @@ -65,6 +65,9 @@ end # Mark GibbsConditional as a valid Gibbs component isgibbscomponent(::GibbsConditional) = true +# Required methods for Gibbs constructor +Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable + """ find_global_varinfo(context, fallback_vi)