Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c0158ea
Add GibbsConditional sampler and corresponding tests
Aug 7, 2025
a972b5a
clarified comment
Aug 7, 2025
bdb7f73
Merge branch 'main' into gibbs-sampler
mhauru Aug 7, 2025
c3cc773
add MHs suggestions
Aug 8, 2025
714c1e8
formatter
Aug 8, 2025
97c571d
Merge branch 'gibbs-sampler' of github.com:TuringLang/Turing.jl into …
Aug 8, 2025
94b723d
fixed exporting thing
Aug 8, 2025
891ac14
Merge branch 'main' into gibbs-sampler
AoifeHughes Aug 18, 2025
2058ae5
Refactor Gibbs sampler to use inverse of parameters for Gamma distrib…
Sep 23, 2025
b0812a3
removed file added by mistake
Sep 25, 2025
d910312
Add safety checks and error handling in find_global_varinfo and Abstr…
Sep 29, 2025
4b1dc2f
imports?
Oct 9, 2025
1e84309
Merge remote-tracking branch 'origin/main' into gibbs-sampler
mhauru Nov 17, 2025
a33d8a9
Fixes and improvements for GibbsConditional
mhauru Nov 17, 2025
f41fc6e
Move GibbsConditional tests to their own file
mhauru Nov 17, 2025
bd5ff0b
More GibbsConditional tests
mhauru Nov 17, 2025
34acad7
Bump patch version to 0.41.2, add HISTORY.md entry
mhauru Nov 17, 2025
4786a59
Remove spurious change
mhauru Nov 17, 2025
8951d98
Code style and documentation
mhauru Nov 18, 2025
45ab5f8
Add one test_throws, tweak test thresholds and dimensions
mhauru Nov 18, 2025
805bc60
Apply suggestions from code review
mhauru Nov 18, 2025
d0c3cf4
Add links for where to get analytical posteriors
mhauru Nov 18, 2025
98f4213
Update TODO note
mhauru Nov 18, 2025
c74b0a0
Fix a GibbsConditional bug, add a test
mhauru Nov 18, 2025
4a7d08c
Set seeds better
mhauru Nov 18, 2025
744d254
Use getvalue in docstring
mhauru Nov 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ export InferenceAlgorithm,
ESS,
Emcee,
Gibbs, # classic sampling
GibbsConditional, # conditional sampling
HMC,
SGLD,
PolynomialStepsize,
Expand Down Expand Up @@ -392,6 +393,7 @@ include("mh.jl")
include("is.jl")
include("particle_mcmc.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("sghmc.jl")
include("emcee.jl")
include("prior.jl")
Expand Down
245 changes: 245 additions & 0 deletions src/mcmc/gibbs_conditional.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
using DynamicPPL: VarName
using Random: Random
import AbstractMCMC

# These functions provide specialized methods for GibbsConditional that extend the generic implementations in gibbs.jl

"""
    GibbsConditional(sym::Symbol, conditional)

A Gibbs sampler component that samples a variable according to a user-provided
analytical conditional distribution.

`conditional` must be a callable that takes a `NamedTuple` and returns a
`Distribution` from which to sample the variable `sym`. The `NamedTuple` passed
to it contains the current values of the model's other random variables merged
with the model's observed arguments (see the `condvals` construction in the
`step` methods below), so the conditional can be written directly in terms of
the model's variable names.

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
    m ~ Normal(0, sqrt(1 / λ))
    for i in 1:length(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# Define analytical conditionals
function cond_λ(c::NamedTuple)
    a = 2.0
    b = 3.0
    m = c.m
    x = c.x
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c::NamedTuple)
    λ = c.λ
    x = c.x
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (λ * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    GibbsConditional(:λ, cond_λ),
    GibbsConditional(:m, cond_m)
), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm
    conditional::C

    # The target variable's symbol is lifted into the first type parameter `S`
    # so that `step` methods can dispatch on (and recover) it at compile time.
    function GibbsConditional(sym::Symbol, conditional::C) where {C}
        return new{sym,C}(conditional)
    end
end

"""
    isgibbscomponent(::GibbsConditional)

Declare that `GibbsConditional` is a valid component sampler for `Gibbs`.
"""
function isgibbscomponent(::GibbsConditional)
    return true
end

"""
    DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Set up a `GibbsConditional` sampler for its first iteration.

`GibbsConditional` is stateless (its state is always `nothing`), so the initial
varinfo `vi` is passed through untouched.
"""
function DynamicPPL.initialstep(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # Nothing to precompute: each step draws directly from the user-supplied
    # conditional distribution, so no warm-up or caching is needed.
    return (nothing, vi)
end

"""
    AbstractMCMC.step(rng, model, sampler::GibbsConditional, state)

Perform a step of GibbsConditional sampling.

Draws a new value for the target variable (the first type parameter `S` of
`GibbsConditional`) from the user-supplied conditional distribution, writes it
into the varinfo, and re-evaluates the model to refresh the log joint.
Returns `(nothing, new_vi)` — the sampler itself keeps no state.
"""
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional{S}},
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
) where {S}
    alg = sampler.alg

    # For GibbsConditional within Gibbs, we need to get all variable values.
    # When running as a Gibbs component the model is wrapped in a GibbsContext,
    # whose global varinfo holds the values of *all* variables, not just the
    # subset this component is responsible for.
    global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
        # We're in a Gibbs context, get the global varinfo
        get_global_varinfo(model.context)
    else
        # We're not in a Gibbs context, use the current state
        state
    end

    # Build the NamedTuple handed to the user's conditional: the (constrained,
    # hence the invlink) values of the random variables merged with the model's
    # observed arguments. Observations take precedence on key collisions.
    condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
    condvals_obs = NamedTuple{keys(model.args)}(model.args)
    condvals = merge(condvals_vars, condvals_obs)

    # Ask the user-supplied function for the conditional distribution.
    conddist = alg.conditional(condvals)

    # Sample from the conditional distribution
    updated = rand(rng, conddist)

    # Update the variable in state.
    # The symbol S (from the sampler's type) tells us which variable to update.
    vn = VarName{S}()

    new_vi = if haskey(state, vn)
        # Plain (un-indexed) variable: overwrite it directly.
        DynamicPPL.setindex!!(state, updated, vn)
    else
        # The variable may be stored under an indexed VarName (e.g. `x[1]`).
        # NOTE(review): this writes the single draw into only the *first* key
        # whose symbol matches S — confirm this is the intended behaviour for
        # multivariate / indexed target variables.
        local updated_vi = state
        found = false
        for key in keys(state)
            if DynamicPPL.getsym(key) == S
                updated_vi = DynamicPPL.setindex!!(state, updated, key)
                found = true
                break
            end
        end
        if !found
            error("Could not find variable $S in VarInfo")
        end
        updated_vi
    end

    # Re-evaluate the model so the stored log joint reflects the new value.
    new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext()))

    return nothing, new_vi
end

"""
    setparams_varinfo!!(model, sampler::GibbsConditional, state, params::AbstractVarInfo)

Install new parameters `params` for a `GibbsConditional` sampler.

Since `GibbsConditional` carries no sampler state (its state is `nothing`),
the incoming varinfo already represents everything there is to update, so it
is simply returned.
"""
function setparams_varinfo!!(
    model::DynamicPPL.Model,
    sampler::DynamicPPL.Sampler{<:GibbsConditional},
    state,
    params::DynamicPPL.AbstractVarInfo,
)
    return params
end

"""
    gibbs_initialstep_recursive(
        rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state
    )

Run the first Gibbs iteration for a `GibbsConditional` component.

No warm-up is required for a conditional sampler, so the very first iteration
is an ordinary step performed with a fresh (`nothing`) state.
"""
function gibbs_initialstep_recursive(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional},
    target_varnames::AbstractVector{<:VarName},
    global_vi::DynamicPPL.AbstractVarInfo,
    prev_state,
)
    # Delegate straight to the regular step; `prev_state` is ignored because
    # GibbsConditional keeps no state between iterations.
    return gibbs_step_recursive(
        rng, model, sampler_wrapped, target_varnames, global_vi, nothing
    )
end

"""
    gibbs_step_recursive(
        rng, model, sampler::GibbsConditional, target_varnames, global_vi, state
    )

Perform a single step of GibbsConditional sampling.

Draws a new value for the target variable (the type parameter `S` of
`GibbsConditional`) from the user-supplied analytical conditional distribution,
writes it into `global_vi`, and re-evaluates the model to refresh the stored
log joint probability.

Returns `(nothing, global_vi)`: `GibbsConditional` carries no sampler state.

Throws an error if no variable with symbol `S` exists in `global_vi`.
"""
function gibbs_step_recursive(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}},
    target_varnames::AbstractVector{<:VarName},
    global_vi::DynamicPPL.AbstractVarInfo,
    state,
) where {S}
    sampler = sampler_wrapped.alg

    # Build the NamedTuple handed to the user's conditional: the (constrained,
    # hence the invlink) values of the random variables merged with the model's
    # observed arguments.
    condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
    condvals_obs = NamedTuple{keys(model.args)}(model.args)
    condvals = merge(condvals_vars, condvals_obs)

    # Ask the user-supplied function for the conditional distribution and draw
    # a new value for the target variable.
    conddist = sampler.conditional(condvals)
    updated = rand(rng, conddist)

    # The type parameter S is the symbol of the variable to update.
    vn = VarName{S}()

    if haskey(global_vi, vn)
        # Plain (un-indexed) variable: overwrite it directly.
        global_vi = DynamicPPL.setindex!!(global_vi, updated, vn)
    else
        # The variable may be stored under an indexed VarName (e.g. `x[1]`).
        # Update the first key whose symbol matches S.
        # NOTE(review): like the AbstractMCMC.step method above, this writes the
        # single draw into only the first matching key — confirm this is the
        # intended behaviour for multivariate / indexed target variables.
        found = false
        for key in keys(global_vi)
            if DynamicPPL.getsym(key) == S
                global_vi = DynamicPPL.setindex!!(global_vi, updated, key)
                found = true
                break
            end
        end
        # Previously this fell through silently, leaving the stale value in
        # place. Fail loudly instead, matching the behaviour of the
        # corresponding `AbstractMCMC.step` method.
        found || error("Could not find variable $S in VarInfo")
    end

    # Re-evaluate the model so the stored log joint reflects the new value.
    global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext()))

    return nothing, global_vi
end
114 changes: 114 additions & 0 deletions test/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,120 @@ end
sampler = Gibbs(:w => HMC(0.05, 10))
@test (sample(model, sampler, 10); true)
end

# Tests for GibbsConditional: a Gibbs component that draws a variable from a
# user-supplied analytical full-conditional distribution.
@testset "GibbsConditional" begin
    # Test with the inverse gamma example from the issue.
    # Conjugate model: λ has a Gamma prior, and both m's prior variance and the
    # likelihood variance are 1/λ, so both full conditionals have closed forms.
    @model function inverse_gdemo(x)
        λ ~ Gamma(2, 3)
        m ~ Normal(0, sqrt(1 / λ))
        for i in 1:length(x)
            x[i] ~ Normal(m, sqrt(1 / λ))
        end
    end

    # Define analytical conditionals.
    # Full conditional of λ given m and x: the (n + 1)/2 shape increment and the
    # m^2 / 2 rate term account for m's own 1/λ-scaled prior contribution.
    function cond_λ(c::NamedTuple)
        a = 2.0
        b = 3.0
        m = c.m
        x = c.x
        n = length(x)
        a_new = a + (n + 1) / 2
        b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
        # Gamma here takes (shape, scale), hence scale = 1 / rate.
        return Gamma(a_new, 1 / b_new)
    end

    # Full conditional of m given λ and x: conjugate normal-normal update with
    # n + 1 effective observations (the prior acts as one extra data point at 0).
    function cond_m(c::NamedTuple)
        λ = c.λ
        x = c.x
        n = length(x)
        m_mean = sum(x) / (n + 1)
        m_var = 1 / (λ * (n + 1))
        return Normal(m_mean, sqrt(m_var))
    end

    # Test basic functionality
    @testset "basic sampling" begin
        Random.seed!(42)
        x_obs = [1.0, 2.0, 3.0, 2.5, 1.5]
        model = inverse_gdemo(x_obs)

        # Test that GibbsConditional works when both components are conditional.
        sampler = Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m))
        chain = sample(model, sampler, 1000)

        # Check that we got the expected variables
        @test :λ in names(chain)
        @test :m in names(chain)

        # Check that the values are reasonable
        λ_samples = vec(chain[:λ])
        m_samples = vec(chain[:m])

        # Given the observed data, we expect certain behavior.
        # These are sanity checks only, not distributional tests.
        @test mean(λ_samples) > 0 # λ should be positive
        @test minimum(λ_samples) > 0
        @test std(m_samples) < 2.0 # m should be relatively well-constrained
    end

    # Test mixing with other samplers
    @testset "mixed samplers" begin
        Random.seed!(42)
        x_obs = [1.0, 2.0, 3.0]
        model = inverse_gdemo(x_obs)

        # Mix GibbsConditional with standard samplers (MH on m).
        sampler = Gibbs(GibbsConditional(:λ, cond_λ), :m => MH())
        chain = sample(model, sampler, 500)

        @test :λ in names(chain)
        @test :m in names(chain)
        @test size(chain, 1) == 500
    end

    # Test with a simpler model
    @testset "simple normal model" begin
        @model function simple_normal(x)
            μ ~ Normal(0, 10)
            σ ~ truncated(Normal(1, 1); lower=0.01)
            for i in 1:length(x)
                x[i] ~ Normal(μ, σ)
            end
        end

        # Conditional for μ given σ and x
        function cond_μ(c::NamedTuple)
            σ = c.σ
            x = c.x
            n = length(x)
            # Prior: μ ~ Normal(0, 10)
            # Likelihood: x[i] ~ Normal(μ, σ)
            # Posterior: μ ~ Normal(μ_post, σ_post)
            prior_var = 100.0 # 10^2
            likelihood_var = σ^2 / n
            # Precision-weighted combination of prior and likelihood.
            post_var = 1 / (1 / prior_var + n / σ^2)
            post_mean = post_var * (0 / prior_var + sum(x) / σ^2)
            return Normal(post_mean, sqrt(post_var))
        end

        Random.seed!(42)
        x_obs = randn(10) .+ 2.0 # Data centered around 2
        model = simple_normal(x_obs)

        sampler = Gibbs(GibbsConditional(:μ, cond_μ), :σ => MH())

        chain = sample(model, sampler, 1000)

        μ_samples = vec(chain[:μ])
        @test abs(mean(μ_samples) - 2.0) < 0.5 # Should be close to true mean
    end

    # Test that GibbsConditional is marked as a valid component
    @testset "isgibbscomponent" begin
        gc = GibbsConditional(:x, c -> Normal(0, 1))
        @test Turing.Inference.isgibbscomponent(gc)
    end
end
end

end
Loading
Loading