Changes from all commits (26 commits)
c0158ea
Add GibbsConditional sampler and corresponding tests
Aug 7, 2025
a972b5a
clarified comment
Aug 7, 2025
bdb7f73
Merge branch 'main' into gibbs-sampler
mhauru Aug 7, 2025
c3cc773
add MHs suggestions
Aug 8, 2025
714c1e8
formatter
Aug 8, 2025
97c571d
Merge branch 'gibbs-sampler' of github.com:TuringLang/Turing.jl into …
Aug 8, 2025
94b723d
fixed exporting thing
Aug 8, 2025
891ac14
Merge branch 'main' into gibbs-sampler
AoifeHughes Aug 18, 2025
2058ae5
Refactor Gibbs sampler to use inverse of parameters for Gamma distrib…
Sep 23, 2025
b0812a3
removed file added by mistake
Sep 25, 2025
d910312
Add safety checks and error handling in find_global_varinfo and Abstr…
Sep 29, 2025
4b1dc2f
imports?
Oct 9, 2025
1e84309
Merge remote-tracking branch 'origin/main' into gibbs-sampler
mhauru Nov 17, 2025
a33d8a9
Fixes and improvements for GibbsConditional
mhauru Nov 17, 2025
f41fc6e
Move GibbsConditional tests to their own file
mhauru Nov 17, 2025
bd5ff0b
More GibbsConditional tests
mhauru Nov 17, 2025
34acad7
Bump patch version to 0.41.2, add HISTORY.md entry
mhauru Nov 17, 2025
4786a59
Remove spurious change
mhauru Nov 17, 2025
8951d98
Code style and documentation
mhauru Nov 18, 2025
45ab5f8
Add one test_throws, tweak test thresholds and dimensions
mhauru Nov 18, 2025
805bc60
Apply suggestions from code review
mhauru Nov 18, 2025
d0c3cf4
Add links for where to get analytical posteriors
mhauru Nov 18, 2025
98f4213
Update TODO note
mhauru Nov 18, 2025
c74b0a0
Fix a GibbsConditional bug, add a test
mhauru Nov 18, 2025
4a7d08c
Set seeds better
mhauru Nov 18, 2025
744d254
Use getvalue in docstring
mhauru Nov 19, 2025
8 changes: 8 additions & 0 deletions HISTORY.md
@@ -1,3 +1,11 @@
# 0.41.2

Add `GibbsConditional`, a "sampler" that can be used to provide analytically known conditional posteriors in a Gibbs sampler.

In Gibbs sampling, each variable is sampled by a component sampler while the other variables are held conditioned at their current values. Typically every component uses an MCMC method, e.g. one variable is sampled with HMC and another with a particle sampler. Sometimes, however, the posterior distribution of a variable is known analytically, given the conditioned values of the other variables. `GibbsConditional` provides a way to implement these analytically known conditional posteriors and use them as component samplers for Gibbs. See the docstring of `GibbsConditional` for details.

Note that `GibbsConditional` used to exist in Turing.jl until v0.36, at which point it was removed when the whole Gibbs sampler was rewritten. This release reintroduces the same functionality, though with a slightly different interface.
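As an illustration of the kind of conjugate update such a conditional encodes (a hedged sketch in plain Julia, not part of the Turing API), here is the conditional posterior of a Normal mean `m` with known precision, matching the update used in the `GibbsConditional` docstring example:

```julia
# Conditional posterior parameters of a Normal mean m, given data x and known
# precision τ, under a Normal(0, 1/sqrt(τ)) prior on m. Purely illustrative;
# the function name and signature are not part of any Turing API.
function cond_m_params(x::Vector{Float64}, precision::Float64)
    n = length(x)
    m_mean = sum(x) / (n + 1)          # posterior mean
    m_var = 1 / (precision * (n + 1))  # posterior variance
    return (m_mean, m_var)
end
```

A `GibbsConditional` component would wrap such an update in a function that reads `x` and `precision` from the conditioned-values `Dict` and returns `Normal(m_mean, sqrt(m_var))`.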

# 0.41.1

The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.41.1"
version = "0.41.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -63,6 +63,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` modu
| `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
| `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
| `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | Gibbs sampling with analytical conditional posterior distributions |
| `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
| `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
| `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |
1 change: 1 addition & 0 deletions src/Turing.jl
@@ -102,6 +102,7 @@ export
Emcee,
ESS,
Gibbs,
GibbsConditional,
HMC,
SGLD,
SGHMC,
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
@@ -56,6 +56,7 @@ export Hamiltonian,
ESS,
Emcee,
Gibbs, # classic sampling
GibbsConditional, # conditional sampling
HMC,
SGLD,
PolynomialStepsize,
@@ -430,6 +431,7 @@ include("mh.jl")
include("is.jl")
include("particle_mcmc.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("sghmc.jl")
include("emcee.jl")
include("prior.jl")
171 changes: 171 additions & 0 deletions src/mcmc/gibbs_conditional.jl
@@ -0,0 +1,171 @@
"""
    GibbsConditional(get_cond_dists)

A Gibbs component sampler that samples variables according to user-provided analytical
conditional posterior distributions.

When using Gibbs sampling, sometimes one may know the analytical form of the posterior for
a given variable, given the conditioned values of the other variables. In such cases one can
use `GibbsConditional` as a component sampler to sample from these known conditionals
directly, avoiding any MCMC methods. One does so with

```julia
sampler = Gibbs(
    (@varname(var1), @varname(var2)) => GibbsConditional(get_cond_dists),
    # other samplers go here...
)
```

Here `get_cond_dists(c::Dict{<:VarName})` should be a function that takes a `Dict` mapping
the conditioned variables (anything other than `var1` and `var2`) to their values, and
returns the conditional posterior distributions for `var1` and `var2`. You may, of course,
have any number of variables sampled as a block in this manner; two are used here only as
an example. The return value of `get_cond_dists` should be one of the following:
- A single `Distribution`, if only one variable is being sampled.
- An `AbstractDict{<:VarName,<:Distribution}` that maps the variables being sampled to
  their conditional posteriors, e.g. `Dict(@varname(var1) => dist1, @varname(var2) => dist2)`.
- A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used
  if all the variable names are single `Symbol`s, and may be more performant, e.g.
  `(; var1=dist1, var2=dist2)`.
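
A minimal sketch of the three return forms (illustrative only; `dist1` and `dist2` stand
for distributions that would in practice be computed from the values in `c`):

```julia
cond_single(c) = dist1                                          # a single `Distribution`
cond_dict(c) = Dict(@varname(var1) => dist1, @varname(var2) => dist2)
cond_namedtuple(c) = (; var1=dist1, var2=dist2)
```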

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    precision ~ Gamma(2, inv(3))
    std = sqrt(1 / precision)
    m ~ Normal(0, std)
    for i in eachindex(x)
        x[i] ~ Normal(m, std)
    end
end

# Define analytical conditionals. See
# https://en.wikipedia.org/wiki/Conjugate_prior#When_likelihood_function_is_a_continuous_distribution
function cond_precision(c)
    a = 2.0
    b = 3.0
    # We use AbstractPPL.getvalue instead of indexing into `c` directly to guard against
    # issues where e.g. you try to get `c[@varname(x[1])]` but only `@varname(x)` is
    # present in `c`. `getvalue` handles that gracefully, `getindex` doesn't. In this case
    # `getindex` would suffice, but `getvalue` is good practice.
    m = AbstractPPL.getvalue(c, @varname(m))
    x = AbstractPPL.getvalue(c, @varname(x))
    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum(abs2, x .- m) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c)
    precision = AbstractPPL.getvalue(c, @varname(precision))
    x = AbstractPPL.getvalue(c, @varname(x))
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (precision * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(model, Gibbs(
    :precision => GibbsConditional(cond_precision),
    :m => GibbsConditional(cond_m)
), 1000)
```
"""
struct GibbsConditional{C} <: AbstractSampler
    get_cond_dists::C
end

isgibbscomponent(::GibbsConditional) = true

"""
    build_variable_dict(model::DynamicPPL.Model)

Traverse the context stack of `model` and build a `Dict` of all the variable values that are
set in GibbsContext, ConditionContext, or FixedContext.
"""
function build_variable_dict(model::DynamicPPL.Model)
    context = model.context
    cond_vals = DynamicPPL.conditioned(context)
    fixed_vals = DynamicPPL.fixed(context)
    # TODO(mhauru) Can we avoid invlinking all the time?
    global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model)
    # TODO(mhauru) This creates a lot of Dicts, which are then immediately merged into one.
    # Also, DynamicPPL.to_varname_dict is known to be inefficient. Make a more efficient
    # implementation.
    return merge(
        DynamicPPL.values_as(global_vi, Dict),
        DynamicPPL.to_varname_dict(cond_vals),
        DynamicPPL.to_varname_dict(fixed_vals),
        DynamicPPL.to_varname_dict(model.args),
    )
end

function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext)
    return if context isa GibbsContext
        get_global_varinfo(context)
    elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent
        get_gibbs_global_varinfo(DynamicPPL.childcontext(context))
    else
        msg = """No GibbsContext found in context stack. Are you trying to use \
            GibbsConditional outside of Gibbs?
            """
        throw(ArgumentError(msg))
    end
end

function initialstep(
    ::Random.AbstractRNG,
    model::DynamicPPL.Model,
    ::GibbsConditional,
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi
    # Since GibbsConditional is only used within Gibbs, it does not need to return a
    # transition.
    return nothing, state
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::GibbsConditional,
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # Get all the conditioned variable values from the model context. This is assumed to
    # include a GibbsContext as part of the context stack.
    condvals = build_variable_dict(model)
    conddists = sampler.get_cond_dists(condvals)

    # We support three different kinds of return values for `sampler.get_cond_dists`, to
    # make life easier for the user.
    if conddists isa AbstractDict
        for (vn, dist) in conddists
            state = setindex!!(state, rand(rng, dist), vn)
        end
    elseif conddists isa NamedTuple
        for (vn_sym, dist) in pairs(conddists)
            vn = VarName{vn_sym}()
            state = setindex!!(state, rand(rng, dist), vn)
        end
    else
        # Single variable case
        vn = only(keys(state))
        state = setindex!!(state, rand(rng, conddists), vn)
    end

    # Since GibbsConditional is only used within Gibbs, it does not need to return a
    # transition.
    return nothing, state
end

function setparams_varinfo!!(
    ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.AbstractVarInfo
)
    return params
end
4 changes: 2 additions & 2 deletions test/mcmc/gibbs.jl
@@ -496,7 +496,7 @@ end

@testset "dynamic model with analytical posterior" begin
# A dynamic model where b ~ Bernoulli determines the dimensionality
# When b=0: single parameter θ₁
# When b=1: two parameters θ₁, θ₂ where we observe their sum
@model function dynamic_bernoulli_normal(y_obs=2.0)
b ~ Bernoulli(0.3)
@@ -575,7 +575,7 @@ end
# end
# end
# sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
#
# because the number of observations in each particle depends on the value
# of `a`.
#