Commits (changes shown from 20 of 26 commits)
c0158ea - Add GibbsConditional sampler and corresponding tests (Aug 7, 2025)
a972b5a - clarified comment (Aug 7, 2025)
bdb7f73 - Merge branch 'main' into gibbs-sampler (mhauru, Aug 7, 2025)
c3cc773 - add MHs suggestions (Aug 8, 2025)
714c1e8 - formatter (Aug 8, 2025)
97c571d - Merge branch 'gibbs-sampler' of github.com:TuringLang/Turing.jl into … (Aug 8, 2025)
94b723d - fixed exporting thing (Aug 8, 2025)
891ac14 - Merge branch 'main' into gibbs-sampler (AoifeHughes, Aug 18, 2025)
2058ae5 - Refactor Gibbs sampler to use inverse of parameters for Gamma distrib… (Sep 23, 2025)
b0812a3 - removed file added by mistake (Sep 25, 2025)
d910312 - Add safety checks and error handling in find_global_varinfo and Abstr… (Sep 29, 2025)
4b1dc2f - imports? (Oct 9, 2025)
1e84309 - Merge remote-tracking branch 'origin/main' into gibbs-sampler (mhauru, Nov 17, 2025)
a33d8a9 - Fixes and improvements for GibbsConditional (mhauru, Nov 17, 2025)
f41fc6e - Move GibbsConditional tests to their own file (mhauru, Nov 17, 2025)
bd5ff0b - More GibbsConditional tests (mhauru, Nov 17, 2025)
34acad7 - Bump patch version to 0.41.2, add HISTORY.md entry (mhauru, Nov 17, 2025)
4786a59 - Remove spurious change (mhauru, Nov 17, 2025)
8951d98 - Code style and documentation (mhauru, Nov 18, 2025)
45ab5f8 - Add one test_throws, tweak test thresholds and dimensions (mhauru, Nov 18, 2025)
805bc60 - Apply suggestions from code review (mhauru, Nov 18, 2025)
d0c3cf4 - Add links for where to get analytical posteriors (mhauru, Nov 18, 2025)
98f4213 - Update TODO note (mhauru, Nov 18, 2025)
c74b0a0 - Fix a GibbsConditional bug, add a test (mhauru, Nov 18, 2025)
4a7d08c - Set seeds better (mhauru, Nov 18, 2025)
744d254 - Use getvalue in docstring (mhauru, Nov 19, 2025)
8 changes: 8 additions & 0 deletions HISTORY.md
@@ -1,3 +1,11 @@
# 0.41.2

Add `GibbsConditional`, a "sampler" that can be used to provide analytically known conditional posteriors in a Gibbs sampler.

In Gibbs sampling, each variable (or block of variables) is sampled with its own component sampler while the other variables are held fixed at their current values. For example, one might alternate between sampling one variable with HMC and another with a particle sampler. Sometimes, however, the posterior distribution of a variable, given the conditioned values of the other variables, is known analytically. `GibbsConditional` provides a way to implement these analytically known conditional posteriors and use them as component samplers for `Gibbs`. See the docstring of `GibbsConditional` for details.

Note that `GibbsConditional` existed in Turing.jl until v0.36, at which point it was removed when the whole Gibbs sampler was rewritten. This release reintroduces the same functionality, though with a slightly different interface.
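As a minimal sketch of the intended usage (the model and the conditional below are illustrative, not part of the release):

```julia
using Turing

@model function demo(x)
    λ ~ Gamma(2, inv(3))  # prior precision: shape 2, scale 1/3, i.e. rate 3
    x ~ Normal(0, 1 / sqrt(λ))
end

# Conjugate Gamma update for λ, given the values of all other variables in `c`.
function cond_λ(c)
    x = c[@varname(x)]
    return Gamma(2 + 1 / 2, 1 / (3 + x^2 / 2))
end

chain = sample(demo(1.5), Gibbs(@varname(λ) => GibbsConditional(cond_λ)), 1000)
```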

# 0.41.1

The `ModeResult` struct returned by `maximum_a_posteriori` and `maximum_likelihood` can now be wrapped in `InitFromParams()`.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.41.1"
version = "0.41.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -63,6 +63,7 @@ even though [`Prior()`](@ref) is actually defined in the `Turing.Inference` module
| `Emcee` | [`Turing.Inference.Emcee`](@ref) | Affine-invariant ensemble sampler |
| `ESS` | [`Turing.Inference.ESS`](@ref) | Elliptical slice sampling |
| `Gibbs` | [`Turing.Inference.Gibbs`](@ref) | Gibbs sampling |
| `GibbsConditional` | [`Turing.Inference.GibbsConditional`](@ref) | Gibbs sampling with analytical conditional posterior distributions |
| `HMC` | [`Turing.Inference.HMC`](@ref) | Hamiltonian Monte Carlo |
| `SGLD` | [`Turing.Inference.SGLD`](@ref) | Stochastic gradient Langevin dynamics |
| `SGHMC` | [`Turing.Inference.SGHMC`](@ref) | Stochastic gradient Hamiltonian Monte Carlo |
1 change: 1 addition & 0 deletions src/Turing.jl
@@ -102,6 +102,7 @@ export
Emcee,
ESS,
Gibbs,
GibbsConditional,
HMC,
SGLD,
SGHMC,
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
@@ -56,6 +56,7 @@ export Hamiltonian,
ESS,
Emcee,
Gibbs, # classic sampling
GibbsConditional, # conditional sampling
HMC,
SGLD,
PolynomialStepsize,
@@ -430,6 +431,7 @@ include("mh.jl")
include("is.jl")
include("particle_mcmc.jl")
include("gibbs.jl")
include("gibbs_conditional.jl")
include("sghmc.jl")
include("emcee.jl")
include("prior.jl")
166 changes: 166 additions & 0 deletions src/mcmc/gibbs_conditional.jl
@@ -0,0 +1,166 @@
"""
GibbsConditional(get_cond_dists)

A Gibbs component sampler that samples variables according to user-provided analytical
conditional posterior distributions.

When using Gibbs sampling, one may know the analytical form of the posterior for a given
variable, conditional on the values of the other variables. In such cases one can use
`GibbsConditional` as a component sampler to sample from these known conditionals
directly, avoiding any MCMC methods:

```julia
sampler = Gibbs(
    (@varname(var1), @varname(var2)) => GibbsConditional(get_cond_dists),
    # other component samplers go here...
)
```

Here `get_cond_dists(c::Dict{<:VarName})` should be a function that takes a `Dict` mapping
the conditioned variables (anything other than `var1` and `var2`) to their values, and
returns the conditional posterior distributions for `var1` and `var2`. You may, of course,
have any number of variables being sampled as a block in this manner; two are used here
only as an example. The return value of `get_cond_dists` should be one of the following:
- A single `Distribution`, if only one variable is being sampled.
- An `AbstractDict{<:VarName,<:Distribution}` mapping the variables being sampled to their
  conditional posteriors, e.g. `Dict(@varname(var1) => dist1, @varname(var2) => dist2)`.
- A `NamedTuple` of `Distribution`s, which is like the `AbstractDict` case but can be used
  if all the variable names are single `Symbol`s, and may be more performant, e.g.
  `(; var1=dist1, var2=dist2)`.
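For illustration, here are three alternative definitions of `get_cond_dists`, one per
return shape (the distributions are placeholders):

```julia
# 1. Single variable being sampled:
get_cond_dists(c) = Normal(0, 1)

# 2. Several variables, as an AbstractDict keyed by VarName:
get_cond_dists(c) = Dict(@varname(var1) => Normal(0, 1), @varname(var2) => Gamma(2, 3))

# 3. Several variables named by plain Symbols, as a NamedTuple:
get_cond_dists(c) = (; var1=Normal(0, 1), var2=Gamma(2, 3))
```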

# Examples

```julia
# Define a model
@model function inverse_gdemo(x)
    precision ~ Gamma(2, inv(3))
    std = sqrt(1 / precision)
    m ~ Normal(0, std)
    for i in 1:length(x)
        x[i] ~ Normal(m, std)
    end
end

# Define analytical conditionals
function cond_precision(c)
    a = 2.0
    b = 3.0  # the prior Gamma(2, inv(3)) has scale 1/3, i.e. rate 3
    m = c[@varname(m)]
    x = c[@varname(x)]
Review comment (Member):

I feel that the c[@varname(x)] here is deceptively simple... it will work under the following scenarios:

1. x is passed as a single vector to the model arguments
2. x is not in the arguments, but is conditioned on as a vector, i.e. model() | (; x = vec)

It will fail in the case where:

3. x is not in the arguments, but is conditioned on as individual elements, i.e. model() | Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0)

Notice that because the model has x[i] ~ Normal, all three cases will work correctly with plain model evaluation. But one of the cases will fail with GibbsConditional.

I don't really know how to fix this, and I don't know whether it should even be fixed, but it makes me quite uncomfortable.
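For concreteness, a sketch of the three cases (hypothetical models h1 and h2, not from this PR):

```julia
using Turing

# Case 1: x is a model argument.
@model function h1(x)
    m ~ Normal()
    for i in 1:length(x)
        x[i] ~ Normal(m, 1)
    end
end
m1 = h1([1.0, 2.0])  # c[@varname(x)] works: x is in model.args

# Cases 2 and 3: x is a random variable, conditioned on from outside.
@model function h2()
    m ~ Normal()
    x = Vector{Float64}(undef, 2)
    for i in 1:2
        x[i] ~ Normal(m, 1)
    end
end
m2 = h2() | (; x = [1.0, 2.0])  # c[@varname(x)] works: the key @varname(x) exists
m3 = h2() | Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0)
# In m3, c[@varname(x)] throws a KeyError: only @varname(x[1]) and @varname(x[2]) exist.
```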

Review comment (Member):

I've fixed all the other points you raised. This will have to wait until tomorrow.

Review comment (penelopeysm, Member, Nov 19, 2025):

As a simple enough solution, could we maybe take the model arguments out of the dictionary, and instead give c a field called c.model_args = deepcopy(model.args)? Then people can use c.model_args.x. The deepcopy would be needed to avoid accidental aliasing. (Or we could not copy, and leave it to the user to copy if they need it.)

Review comment (penelopeysm, Member, Nov 19, 2025):

The outcome of this is:

1. x is passed as a single vector to the model arguments: OK, people can use c.model_args.x, clear.
2. x is not in the arguments, but is conditioned on as a vector, i.e. model() | (; x = vec): OK, people can use c[@varname(x)] because that's what they conditioned on.
3. x is not in the arguments, but is conditioned on as individual elements, i.e. model() | Dict(@varname(x[1]) => 1.0, @varname(x[2]) => 2.0): this will still fail, but at least there's a good explanation for it, namely that they conditioned on x[1] and x[2] rather than x.

Review comment (Member):

Looking at my previous comment, actually the original implementation is also explainable along the same lines. I guess the tl;dr is: if you supply x as a single thing, that's fine. So actually I'd be OK with leaving it as is.

Review comment (Member):

The root cause of this is essentially the same as why you can't do

```julia
@model function f()
    x ~ MvNormal(zeros(2), I)
end

m = condition(f(), Dict(@varname(x[1]) => 1.0))
```

it's just that in that case you request the value of the conditioned x in the model, whereas here you request it in the conditionals function. So in some sense this is a bigger problem that needs to be fixed lower down (https://github.com/TuringLang/DynamicPPL.jl/issues/11480). So I'd be okay with leaving this be too.

The one thing that I was wondering about is that ConditionContext internally uses getvalue. If the user did the same in the conditionals function, then at least they could get x1 = getvalue(c, @varname(x[1])) to work even if they conditioned on x as a whole. Thus I'm now thinking that maybe the best thing to do here would be to set a good example in the docstring and use getvalue rather than getindex, and leave it at that.

As a silver lining, if this goes wrong, at least the user gets a clear "key not found" error, and the error comes from code they wrote, in quite a traceable way.
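A sketch of the distinction, assuming c is the Dict keyed by VarName that gets passed to the conditionals function:

```julia
using AbstractPPL: getvalue, @varname, VarName

c = Dict{VarName,Any}(@varname(x) => [1.0, 2.0])

haskey(c, @varname(x[1]))    # false: getindex would need an exact key match
getvalue(c, @varname(x[1]))  # 1.0: getvalue can index into the @varname(x) entry
```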

Review comment (Member):

Yeah, that sounds good to me! Agree that getvalue is better, albeit a tiny bit more annoying since it has to be imported from AbstractPPL.

    n = length(x)
    a_new = a + (n + 1) / 2
    b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
    return Gamma(a_new, 1 / b_new)
end

function cond_m(c)
    precision = c[@varname(precision)]
    x = c[@varname(x)]
    n = length(x)
    m_mean = sum(x) / (n + 1)
    m_var = 1 / (precision * (n + 1))
    return Normal(m_mean, sqrt(m_var))
end

# Sample using GibbsConditional
model = inverse_gdemo([1.0, 2.0, 3.0])
chain = sample(
    model,
    Gibbs(
        :precision => GibbsConditional(cond_precision),
        :m => GibbsConditional(cond_m),
    ),
    1000,
)
```
"""
struct GibbsConditional{C} <: AbstractSampler
    get_cond_dists::C
end

isgibbscomponent(::GibbsConditional) = true

"""
    build_variable_dict(model::DynamicPPL.Model)

Traverse the context stack of `model` and build a `Dict` of all the variable values that are
set in GibbsContext, ConditionContext, or FixedContext.
"""
function build_variable_dict(model::DynamicPPL.Model)
    context = model.context
    cond_nt = DynamicPPL.conditioned(context)
    fixed_nt = DynamicPPL.fixed(context)
    # TODO(mhauru) Can we avoid invlinking all the time? Note that this causes a model
    # evaluation, which may be expensive.
Review comment (Member):

I think for typed VarInfo it shouldn't need to evaluate the model. Obviously it still has a cost, just not as much as model evaluation.

```julia
julia> @model function f()
           @info "hi"
           x ~ Normal()
       end
f (generic function with 2 methods)

julia> model = f(); v = VarInfo(model);
[ Info: hi

julia> v2 = DynamicPPL.link!!(v, model); v3 = DynamicPPL.invlink!!(v, model);
```

Review comment (Member):

Adjusted the comment to reflect this.

    global_vi = DynamicPPL.invlink(get_gibbs_global_varinfo(context), model)
    return merge(
        DynamicPPL.values_as(global_vi, Dict),
        Dict(
            (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(cond_nt))...,
            (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(fixed_nt))...,
            (DynamicPPL.VarName{sym}() => val for (sym, val) in pairs(model.args))...,
Review comment (Member):

cond_nt and fixed_nt might not be NamedTuples, they might be dicts, in which case this will fail in a very weird manner.

```julia
julia> VarName{@varname(x)}()
x

julia> VarName{@varname(x)}() == @varname(x) # not the same thing
false
```

I think you have to convert them to dicts first. DynamicPPL.to_varname_dict will do this (albeit inefficiently TuringLang/DynamicPPL.jl#1134).
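A sketch of the suggested conversion (assuming, as the comment above says, that to_varname_dict accepts NamedTuples):

```julia
# Convert Symbol-keyed values into a Dict keyed by proper VarNames, so that
# comparisons and lookups like d[@varname(x)] behave consistently.
d = DynamicPPL.to_varname_dict((; x = 1.0))
d[@varname(x)]  # 1.0, found under a true VarName key rather than a Symbol
```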

Review comment (Member):

Nice catch. I fixed this and added a test that would have caught it. Making so many Dicts is sad, but for now I just want to get this working and worry about performance later.

        ),
    )
end

function get_gibbs_global_varinfo(context::DynamicPPL.AbstractContext)
    return if context isa GibbsContext
        get_global_varinfo(context)
    elseif DynamicPPL.NodeTrait(context) isa DynamicPPL.IsParent
        get_gibbs_global_varinfo(DynamicPPL.childcontext(context))
    else
        msg = """No GibbsContext found in context stack. Are you trying to use \
            GibbsConditional outside of Gibbs?
            """
        throw(ArgumentError(msg))
    end
end

function initialstep(
    ::Random.AbstractRNG,
    model::DynamicPPL.Model,
    ::GibbsConditional,
    vi::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    state = DynamicPPL.is_transformed(vi) ? DynamicPPL.invlink(vi, model) : vi
    # Since GibbsConditional is only used within Gibbs, it does not need to return a
    # transition.
    return nothing, state
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::GibbsConditional,
    state::DynamicPPL.AbstractVarInfo;
    kwargs...,
)
    # Get all the conditioned variable values from the model context. This is assumed to
    # include a GibbsContext as part of the context stack.
    condvals = build_variable_dict(model)
    conddists = sampler.get_cond_dists(condvals)

    # We support three different kinds of return values for `sampler.get_cond_dists`, to
    # make life easier for the user.
    if conddists isa AbstractDict
        for (vn, dist) in conddists
            state = setindex!!(state, rand(rng, dist), vn)
        end
    elseif conddists isa NamedTuple
        for (vn_sym, dist) in pairs(conddists)
            vn = VarName{vn_sym}()
            state = setindex!!(state, rand(rng, dist), vn)
        end
    else
        # Single variable case.
        vn = only(keys(state))
        state = setindex!!(state, rand(rng, conddists), vn)
    end

    # Since GibbsConditional is only used within Gibbs, it does not need to return a
    # transition.
    return nothing, state
end

function setparams_varinfo!!(
    ::DynamicPPL.Model, ::GibbsConditional, ::Any, params::DynamicPPL.AbstractVarInfo
)
    return params
end
4 changes: 2 additions & 2 deletions test/mcmc/gibbs.jl
@@ -496,7 +496,7 @@ end

@testset "dynamic model with analytical posterior" begin
    # A dynamic model where b ~ Bernoulli determines the dimensionality
    # When b=0: single parameter θ₁
    # When b=1: two parameters θ₁, θ₂ where we observe their sum
    @model function dynamic_bernoulli_normal(y_obs=2.0)
        b ~ Bernoulli(0.3)
@@ -575,7 +575,7 @@ end
    # end
    # end
    # sample(f(), Gibbs(:a => PG(10), :x => MH()), 1000)
    #
    # because the number of observations in each particle depends on the value
    # of `a`.
    #