Skip to content

Commit f7813b2

Browse files
phipsgablerdevmotiontrappmartin
authored
Gibbs Conditionals (new PR) (#1275)
Co-authored-by: David Widmann <[email protected]> Co-authored-by: Martin Trapp <[email protected]>
1 parent 99ba159 commit f7813b2

File tree

8 files changed

+255
-2
lines changed

8 files changed

+255
-2
lines changed

src/Turing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ export @model, # modelling
9191
Emcee,
9292
ESS,
9393
Gibbs,
94+
GibbsConditional,
9495

9596
HMC, # Hamiltonian-like sampling
9697
SGLD,

src/inference/Inference.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ export InferenceAlgorithm,
4242
ESS,
4343
Emcee,
4444
Gibbs, # classic sampling
45+
GibbsConditional,
4546
HMC,
4647
SGLD,
4748
SGHMC,
@@ -538,6 +539,7 @@ include("mh.jl")
538539
include("is.jl")
539540
include("AdvancedSMC.jl")
540541
include("gibbs.jl")
542+
include("gibbs_conditional.jl")
541543
include("../contrib/inference/sghmc.jl")
542544
include("emcee.jl")
543545

src/inference/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ function Sampler(alg::Gibbs, model::Model, s::Selector)
8080
else
8181
prev_alg = alg.algs[i-1]
8282
end
83-
rerun = !(_alg isa MH) || prev_alg isa PG || prev_alg isa ESS
83+
rerun = !(_alg isa MH) || prev_alg isa PG || prev_alg isa ESS || prev_alg isa GibbsConditional
8484
selector = Selector(Symbol(typeof(_alg)), rerun)
8585
Sampler(_alg, model, selector)
8686
end

src/inference/gibbs_conditional.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
2+
GibbsConditional(sym, conditional)
3+
4+
A "pseudo-sampler" to manually provide analytical Gibbs conditionals to `Gibbs`.
5+
`GibbsConditional(:x, cond)` will sample the variable `x` according to the conditional `cond`, which
6+
must therefore be a function from a `NamedTuple` of the conditioned variables to a `Distribution`.
7+
8+
9+
The `NamedTuple` that is passed in contains all random variables from the model in an unspecified
10+
order, taken from the [`VarInfo`](@ref) object over which the model is run. Scalars and vectors are
11+
stored in their respective shapes. The tuple also contains the value of the conditioned variable
12+
itself, which can be useful, but using it creates something that is not a Gibbs sampler anymore (see
13+
[here](https://github.com/TuringLang/Turing.jl/pull/1275#discussion_r434240387)).
14+
15+
# Examples
16+
17+
```julia
18+
α_0 = 2.0
19+
θ_0 = inv(3.0)
20+
x = [1.5, 2.0]
21+
N = length(x)
22+
23+
@model function inverse_gdemo(x)
24+
λ ~ Gamma(α_0, θ_0)
25+
σ = sqrt(1 / λ)
26+
m ~ Normal(0, σ)
27+
@. x ~ \$(Normal(m, σ))
28+
end
29+
30+
# The conditionals can be formulated in terms of the following statistics:
31+
x_bar = mean(x) # sample mean
32+
s2 = var(x; mean=x_bar, corrected=false) # sample variance
33+
m_n = N * x_bar / (N + 1)
34+
35+
function cond_m(c)
36+
λ_n = c.λ * (N + 1)
37+
σ_n = sqrt(1 / λ_n)
38+
return Normal(m_n, σ_n)
39+
end
40+
41+
function cond_λ(c)
42+
α_n = α_0 + (N - 1) / 2 + 1
43+
β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0)
44+
return Gamma(α_n, inv(β_n))
45+
end
46+
47+
m = inverse_gdemo(x)
48+
49+
sample(m, Gibbs(GibbsConditional(:λ, cond_λ), GibbsConditional(:m, cond_m)), 10)
50+
```
51+
"""
52+
struct GibbsConditional{S, C}
53+
conditional::C
54+
55+
function GibbsConditional(sym::Symbol, conditional::C) where {C}
56+
return new{sym, C}(conditional)
57+
end
58+
end
59+
60+
DynamicPPL.getspace(::GibbsConditional{S}) where {S} = (S,)
61+
DynamicPPL.alg_str(::GibbsConditional) = "GibbsConditional"
62+
isgibbscomponent(::GibbsConditional) = true
63+
64+
65+
function Sampler(
66+
alg::GibbsConditional,
67+
model::Model,
68+
s::Selector=Selector()
69+
)
70+
return Sampler(alg, Dict{Symbol, Any}(), s, SamplerState(VarInfo(model)))
71+
end
72+
73+
74+
function AbstractMCMC.step!(
75+
rng::AbstractRNG,
76+
model::Model,
77+
spl::Sampler{<:GibbsConditional{S}},
78+
N::Integer,
79+
transition;
80+
kwargs...
81+
) where {S}
82+
if spl.selector.rerun # Recompute joint in logp
83+
model(spl.state.vi)
84+
end
85+
86+
condvals = conditioned(tonamedtuple(spl.state.vi))
87+
conddist = spl.alg.conditional(condvals)
88+
updated = rand(rng, conddist)
89+
spl.state.vi[VarName(S)] = [updated;] # setindex allows only vectors in this case...
90+
91+
return transition
92+
end
93+
94+
95+
"""
96+
conditioned(θ::NamedTuple)
97+
98+
Extract a `NamedTuple` of the values in `θ`; i.e., all names of `θ`, mapping to their respective
99+
values.
100+
101+
`θ` is assumed to come from `tonamedtuple(vi)`, which returns a `NamedTuple` of the form
102+
103+
```julia
104+
t = (m = ([0.234, -1.23], ["m[1]", "m[2]"]), λ = ([1.233], ["λ"]))
105+
```
106+
107+
and this function implements the cleanup of indexing. `conditioned(t)` will therefore return
108+
109+
```julia
110+
(λ = 1.233, m = [0.234, -1.23])
111+
```
112+
"""
113+
@generated function conditioned(θ::NamedTuple{names}) where {names}
114+
condvals = [:($n = extractparam(θ.$n)) for n in names]
115+
return Expr(:tuple, condvals...)
116+
end
117+
118+
119+
"""Takes care of removing the `tonamedtuple` indexing form."""
120+
extractparam(p::Tuple{Vector{<:Array{<:Real}}, Vector{String}}) = foldl(vcat, p[1])
121+
function extractparam(p::Tuple{Vector{<:Real}, Vector{String}})
122+
values, strings = p
123+
if length(values) == length(strings) == 1 && !occursin(r".\[.+\]$", strings[1])
124+
# if m ~ MvNormal(1, 1), we could have ([1], ["m[1]"])!
125+
return values[1]
126+
else
127+
return values
128+
end
129+
end

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
44
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
5+
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
56
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
67
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
78
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
@@ -29,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2930
AbstractMCMC = "1.0.1"
3031
AdvancedMH = "0.5.1"
3132
AdvancedVI = "0.1"
33+
Clustering = "0.14"
3234
CmdStan = "6.0.8"
3335
Distributions = "0.23.8"
3436
DistributionsAD = "0.6.3"

test/inference/gibbs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ include(dir*"/test/test_utils/AllUtils.jl")
114114
alg = Gibbs(MH(:s), HMC(0.2, 4, :m))
115115
sample(model, alg, 100; callback = callback)
116116
end
117-
118117
@turing_testset "dynamic model" begin
119118
@model imm(y, alpha, ::Type{M}=Vector{Float64}) where {M} = begin
120119
N = length(y)

test/inference/gibbs_conditional.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
using Random, Turing, Test
2+
using Clustering
3+
4+
dir = splitdir(splitdir(pathof(Turing))[1])[1]
5+
include(dir*"/test/test_utils/AllUtils.jl")
6+
7+
8+
@turing_testset "gibbs conditionals" begin
9+
Random.seed!(100)
10+
11+
@turing_testset "gdemo" begin
12+
N = 1000
13+
(α_0, θ_0) = (2.0, inv(3.0))
14+
λ_true = rand(Gamma(α_0, θ_0))
15+
σ_true = sqrt(1 / λ_true)
16+
m_true = rand(Normal(0, σ_true))
17+
x = rand(Normal(m_true, σ_true), N)
18+
19+
# The conditionals and posterior can be formulated in terms of the following statistics:
20+
x_bar = mean(x) # sample mean
21+
s2 = var(x; mean=x_bar, corrected=false) # sample variance
22+
m_n = N * x_bar / (N + 1)
23+
24+
@model function inverse_gdemo(x)
25+
λ ~ Gamma(α_0, θ_0)
26+
σ = sqrt(1 / λ)
27+
m ~ Normal(0, σ)
28+
@. x ~ $(Normal(m, σ))
29+
end
30+
31+
function cond_m(c)
32+
λ_n = c.λ * (N + 1)
33+
σ_n = sqrt(1 / λ_n)
34+
return Normal(m_n, σ_n)
35+
end
36+
37+
function cond_λ(c)
38+
α_n = α_0 + (N - 1) / 2 + 1
39+
β_n = s2 * N / 2 + c.m^2 / 2 + inv(θ_0)
40+
return Gamma(α_n, inv(β_n))
41+
end
42+
43+
# Three tests: one for each variable fixed to the true value, and one for both
44+
# using the conditional
45+
for alg in (Gibbs(GibbsConditional(:m, cond_m),
46+
GibbsConditional(:λ, c -> Normal(λ_true, 0))),
47+
Gibbs(GibbsConditional(:m, c -> Normal(m_true, 0)),
48+
GibbsConditional(:λ, cond_λ)),
49+
Gibbs(GibbsConditional(:m, cond_m),
50+
GibbsConditional(:λ, cond_λ)))
51+
chain = sample(inverse_gdemo(x), alg, 10_000)
52+
check_numerical(chain, [:m, :λ], [m_true, λ_true], atol=0.2)
53+
end
54+
end
55+
56+
@turing_testset "GMM" begin
57+
π = [0.5, 0.5] # cluster weights
58+
K = length(π) # number of clusters
59+
m = 0.5 # prior mean
60+
s = 2.0 # prior variance
61+
σ = 0.1 # observation variance
62+
N = 20 # number of observations
63+
64+
μ_true = rand(Normal(m, s), K)
65+
z_true = rand(Categorical(π), N)
66+
x = rand(MvNormal(μ_true[z_true], σ))
67+
68+
@model function mixture(x)
69+
μ ~ MvNormal(fill(m, K), s)
70+
z ~ filldist(Categorical(π), N)
71+
x ~ MvNormal(μ[z], σ)
72+
return x
73+
end
74+
75+
# see http://www.cs.columbia.edu/~blei/fogm/2015F/notes/mixtures-and-gibbs.pdf
76+
function cond_z(c)
77+
function mixtureweight(x)
78+
p = π .* pdf.(Normal.(c.μ, σ), x)
79+
return p ./ sum(p)
80+
end
81+
return arraydist(Categorical.(mixtureweight.(x)))
82+
end
83+
84+
function cond_μ(c)
85+
z = c.z
86+
n = [count(z .== k) for k = 1:K]
87+
88+
# If there were no observations assigned to center `k`, `n[k] == 0`, and
89+
# we use the prior instead.
90+
s_hat = [(n[k] != 0) ? inv(n[k] / σ^2 + 1/s^2) : s for k = 1:K]
91+
μ_hat = [(n[k] != 0) ? (sum(x[z .== k]) / σ^2) * s_hat[k] : m for k = 1:K]
92+
93+
return MvNormal(μ_hat, s_hat)
94+
end
95+
96+
estimate(chain, var) = dropdims(mean(Array(group(chain, var)), dims=1), dims=1)
97+
function estimatez(chain, var, range)
98+
z = Int.(Array(group(chain, var)))
99+
return map(i -> findmax(counts(z[:,i], range))[2], 1:size(z,2))
100+
end
101+
102+
lμ_true, uμ_true = extrema(μ_true)
103+
104+
for alg in (Gibbs(GibbsConditional(:z, cond_z), GibbsConditional(:μ, cond_μ)),
105+
Gibbs(GibbsConditional(:z, cond_z), MH(:μ)),
106+
Gibbs(GibbsConditional(:z, cond_z), HMC(0.01, 7, :μ)), )
107+
108+
chain = sample(mixture(x), alg, 10000)
109+
110+
μ_hat = estimate(chain, :μ)
111+
lμ_hat, uμ_hat = extrema(μ_hat)
112+
@test isapprox([lμ_true, uμ_true], [lμ_hat, uμ_hat], atol=0.1)
113+
114+
z_hat = estimatez(chain, :z, 1:2)
115+
ari, _, _, _ = randindex(z_true, Int.(z_hat))
116+
@test isapprox(ari, 1, atol=0.1)
117+
end
118+
end
119+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ include("test_utils/AllUtils.jl")
2323
@testset "inference: $adbackend" begin
2424
@testset "samplers" begin
2525
include("inference/gibbs.jl")
26+
include("inference/gibbs_conditional.jl")
2627
include("inference/hmc.jl")
2728
include("inference/is.jl")
2829
include("inference/mh.jl")

0 commit comments

Comments
 (0)