Skip to content
Open
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 40 additions & 112 deletions src/mcmc/gibbs_conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,48 @@ chain = sample(model, Gibbs(
), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm
struct GibbsConditional{C} <: InferenceAlgorithm
conditional::C

function GibbsConditional(sym::Symbol, conditional::C) where {C}
return new{sym,C}(conditional)
return new{C}(conditional)
end
end

# Mark GibbsConditional as a valid Gibbs component
isgibbscomponent(::GibbsConditional) = true

# Required methods for Gibbs constructor
Base.length(::GibbsConditional) = 1 # Each GibbsConditional handles one variable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this being called somewhere? Might be, but I don't remember having a need for length of samplers.


"""
find_global_varinfo(context, fallback_vi)

Traverse the context stack to find global variable information from
GibbsContext, ConditionContext, FixedContext, etc.
"""
function find_global_varinfo(context, fallback_vi)
# Start with the given context and traverse down
current_context = context

while current_context !== nothing
if current_context isa GibbsContext
# Found GibbsContext, return its global varinfo
return get_global_varinfo(current_context)
elseif hasproperty(current_context, :childcontext) &&
isdefined(DynamicPPL, :childcontext)
Comment on lines +85 to +86
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check can be done more robustly with the DynamicPPL.NodeTrait function, that is meant to check exactly this. If you write this as a recursion rather than as a loop, you can use multiple dispatch on the IsLeaf and IsParent types to implement this check. This is best explained by an example, see for instance the implementation of hassampler here: https://github.com/TuringLang/DynamicPPL.jl/blob/0cf3440549c7634c9b810e1f804eb2cd195f7e47/src/contexts.jl#L174

To explain a bit: DynamicPPL.NodeTrait(ctx) should return either IsLeaf or IsParent for all contexts ctx. Crucially those are types, not values, which means you can dispatch methods on them. The contract is that if DynamicPPL.NodeTrait(ctx) returns IsParent then there should be an implementation of childcontext(ctx) which returns the child context.

This design pattern is often called Holy Traits (after Tim Holy who presumably came up with them). It's kinda funky but cool, and worth learning, since it comes up with some frequency in Julia code. The benefit of Holy Traits is that the type of information of e.g. IsLeaf vs IsParent is known at compile time, and thus results in more efficient code.

# Move to child context if it exists
current_context = DynamicPPL.childcontext(current_context)
else
# No more child contexts
break
end
end

# If no GibbsContext found, use the fallback
return fallback_vi
end

"""
DynamicPPL.initialstep(rng, model, sampler::GibbsConditional, vi)

Expand Down Expand Up @@ -97,12 +128,10 @@ function AbstractMCMC.step(
alg = sampler.alg

# For GibbsConditional within Gibbs, we need to get all variable values
# Check if we're in a Gibbs context
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
# We're in a Gibbs context, get the global varinfo
get_global_varinfo(model.context)
# Traverse the context stack to find all conditioned/fixed/Gibbs variables
global_vi = if hasproperty(model, :context)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since model is a DynamicPPL.Model, it has to always have the field context (that's part of the definition of DynamicPPL.Model), and thus this check is unnecessary.

find_global_varinfo(model.context, state)
else
# We're not in a Gibbs context, use the current state
state
end

Expand All @@ -119,34 +148,10 @@ function AbstractMCMC.step(
updated = rand(rng, conddist)

# Update the variable in state
# We need to get the actual VarName for this variable
# The symbol S tells us which variable to update
vn = VarName{S}()

# Check if the variable needs to be a vector
new_vi = if haskey(state, vn)
# Update the existing variable
DynamicPPL.setindex!!(state, updated, vn)
else
# Try to find the variable with indices
# This handles cases where the variable might have indices
local updated_vi = state
found = false
for key in keys(state)
if DynamicPPL.getsym(key) == S
updated_vi = DynamicPPL.setindex!!(state, updated, key)
found = true
break
end
end
if !found
error("Could not find variable $S in VarInfo")
end
updated_vi
end

# Update log joint probability
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext()))
# The Gibbs sampler ensures that state only contains one variable
# Get the variable name from the keys
varname = first(keys(state))
new_vi = DynamicPPL.setindex!!(state, updated, varname)

return nothing, new_vi
end
Expand All @@ -166,80 +171,3 @@ function setparams_varinfo!!(
# the state is nothing and we don't need to update anything
return params
end

"""
gibbs_initialstep_recursive(
rng, model, sampler::GibbsConditional, target_varnames, global_vi, prev_state
)

Initialize the GibbsConditional sampler.
"""
function gibbs_initialstep_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional},
target_varnames::AbstractVector{<:VarName},
global_vi::DynamicPPL.AbstractVarInfo,
prev_state,
)
# GibbsConditional doesn't need any special initialization
# Just perform one sampling step
return gibbs_step_recursive(
rng, model, sampler_wrapped, target_varnames, global_vi, nothing
)
end

"""
gibbs_step_recursive(
rng, model, sampler::GibbsConditional, target_varnames, global_vi, state
)

Perform a single step of GibbsConditional sampling.
"""
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapped::DynamicPPL.Sampler{<:GibbsConditional{S}},
target_varnames::AbstractVector{<:VarName},
global_vi::DynamicPPL.AbstractVarInfo,
state,
) where {S}
sampler = sampler_wrapped.alg

# Extract conditioned values as a NamedTuple
# Include both random variables and observed data
condvals_vars = DynamicPPL.values_as(DynamicPPL.invlink(global_vi, model), NamedTuple)
condvals_obs = NamedTuple{keys(model.args)}(model.args)
condvals = merge(condvals_vars, condvals_obs)

# Get the conditional distribution
conddist = sampler.conditional(condvals)

# Sample from the conditional distribution
updated = rand(rng, conddist)

# Update the variable in global_vi
# We need to get the actual VarName for this variable
# The symbol S tells us which variable to update
vn = VarName{S}()

# Check if the variable needs to be a vector
if haskey(global_vi, vn)
# Update the existing variable
global_vi = DynamicPPL.setindex!!(global_vi, updated, vn)
else
# Try to find the variable with indices
# This handles cases where the variable might have indices
for key in keys(global_vi)
if DynamicPPL.getsym(key) == S
global_vi = DynamicPPL.setindex!!(global_vi, updated, key)
break
end
end
end

# Update log joint probability
global_vi = last(DynamicPPL.evaluate!!(model, global_vi, DynamicPPL.DefaultContext()))

return nothing, global_vi
end
Loading