Skip to content

Commit 3afc232

Browse files
committed
more progress; still need to deal with w being on simplex
1 parent 590d37f commit 3afc232

File tree

4 files changed

+258
-48
lines changed

4 files changed

+258
-48
lines changed

gibbs_example/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
66
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
77
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
88
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1011
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

gibbs_example/gibbs.jl

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1+
using AbstractMCMC
12
using LogDensityProblems, Distributions, LinearAlgebra, Random
23
using OrderedCollections
34

5+
##
6+
7+
# TODO: introduce some kind of parameter format, for instance, a flattened vector
8+
# then define some kind of function to transform the flattened vector into model's representation
9+
410
struct Gibbs <: AbstractMCMC.AbstractSampler
511
sampler_map::OrderedDict
612
end
713

814
struct GibbsState
9-
values::NamedTuple
15+
vi::NamedTuple
1016
states::OrderedDict
1117
end
1218

@@ -15,31 +21,91 @@ struct GibbsTransition
1521
end
1622

1723
function AbstractMCMC.step(
18-
rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs...
24+
rng::AbstractRNG,
25+
logdensity_model::AbstractMCMC.LogDensityModel,
26+
spl::Gibbs,
27+
args...;
28+
initial_params::NamedTuple,
29+
kwargs...,
1930
)
2031
states = OrderedDict()
21-
for group in keys(sampler.sampler_map)
22-
sampler = sampler.sampler_map[group]
23-
cond_val = NamedTuple{group}([initial_params[g] for g in group]...)
24-
trans, state = AbstractMCMC.step(
25-
rng, condition(model, cond_val), sampler, args...; kwargs...
32+
for group in keys(spl.sampler_map)
33+
sub_spl = spl.sampler_map[group]
34+
35+
vars_to_be_conditioned_on = setdiff(keys(initial_params), group)
36+
cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}(
37+
Tuple([initial_params[g] for g in vars_to_be_conditioned_on])
38+
)
39+
params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group]))
40+
sub_state = last(
41+
AbstractMCMC.step(
42+
rng,
43+
AbstractMCMC.LogDensityModel(
44+
condition(logdensity_model.logdensity, cond_val)
45+
),
46+
sub_spl,
47+
args...;
48+
initial_params=flatten(params_val),
49+
kwargs...,
50+
),
2651
)
27-
states[group] = state
52+
states[group] = sub_state
2853
end
2954
return GibbsTransition(initial_params), GibbsState(initial_params, states)
3055
end
3156

3257
function AbstractMCMC.step(
33-
rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs...
58+
rng::AbstractRNG,
59+
logdensity_model::AbstractMCMC.LogDensityModel,
60+
spl::Gibbs,
61+
state::GibbsState,
62+
args...;
63+
kwargs...,
3464
)
35-
for group in collect(keys(sampler.sampler_map))
36-
sampler = sampler.sampler_map[group]
37-
state = state.states[group]
38-
trans, state = AbstractMCMC.step(
39-
rng, condition(model, state.values[group]), sampler, state, args...; kwargs...
65+
vi = state.vi
66+
for group in keys(spl.sampler_map)
67+
for (group, sub_state) in state.states
68+
vi = merge(vi, unflatten(getparams(sub_state), group))
69+
end
70+
sub_spl = spl.sampler_map[group]
71+
sub_state = state.states[group]
72+
group_complement = setdiff(keys(vi), group)
73+
cond_val = NamedTuple{Tuple(group_complement)}(
74+
Tuple([vi[g] for g in group_complement])
75+
)
76+
sub_state = last(
77+
AbstractMCMC.step(
78+
rng,
79+
AbstractMCMC.LogDensityModel(
80+
condition(logdensity_model.logdensity, cond_val)
81+
),
82+
sub_spl,
83+
sub_state,
84+
args...;
85+
kwargs...,
86+
),
4087
)
41-
# TODO: what values to condition on here? stored where?
42-
state.states[group] = state
88+
state.states[group] = sub_state
4389
end
44-
return nothing
90+
for sub_state in values(state.states)
91+
vi = merge(vi, getparams(sub_state))
92+
end
93+
return GibbsTransition(vi), GibbsState(vi, state.states)
4594
end
95+
96+
## tests
97+
98+
gmm = GMM((; x=x))
99+
100+
samples = sample(
101+
gmm,
102+
Gibbs(
103+
OrderedDict(
104+
(:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
105+
(:w,) => PriorMH(Dirichlet(2, 1.0)),
106+
(, :w) => RWMH(1),
107+
),
108+
),
109+
10000;
110+
initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]),
111+
)

gibbs_example/gmm.jl

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ function log_joint(; μ, w, z, x)
2929
logp += logpdf(w_prior, w)
3030

3131
z_prior = Categorical(w)
32+
3233
logp += sum([logpdf(z_prior, z[i]) for i in 1:N])
3334

3435
obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ]
@@ -43,33 +44,80 @@ function condition(gmm::GMM, conditioned_values::NamedTuple)
4344
return ConditionedGMM(gmm.data, conditioned_values)
4445
end
4546

46-
function _logdensity(gmm::ConditionedGMM{(:μ, :w)}, params)
47+
function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params)
4748
return log_joint(;
4849
μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x
4950
)
5051
end
52+
5153
function _logdensity(gmm::ConditionedGMM{(:z,)}, params)
5254
return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x)
5355
end
5456

5557
function LogDensityProblems.logdensity(
56-
gmm::ConditionedGMM{(:μ, :w)}, params_vec::AbstractVector
58+
gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}},
59+
params_vec::AbstractVector,
5760
)
61+
@assert length(params_vec) == 60
5862
return _logdensity(gmm, (; z=params_vec))
5963
end
6064
function LogDensityProblems.logdensity(
6165
gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector
6266
)
67+
@assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))"
6368
return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4]))
6469
end
6570

66-
function LogDensityProblems.dimension(gmm::ConditionedGMM{(:μ, :w)})
67-
return size(gmm.data.x, 1)
71+
function LogDensityProblems.dimension(gmm::GMM)
72+
return 4 + size(gmm.data.x, 1)
73+
end
74+
75+
function LogDensityProblems.dimension(
76+
gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}
77+
)
78+
return 4
6879
end
80+
6981
function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)})
7082
return size(gmm.data.x, 1)
7183
end
7284

85+
function LogDensityProblems.capabilities(::GMM)
86+
return LogDensityProblems.LogDensityOrder{0}()
87+
end
88+
89+
function LogDensityProblems.capabilities(::ConditionedGMM)
90+
return LogDensityProblems.LogDensityOrder{0}()
91+
end
92+
93+
function flatten(nt::NamedTuple)
94+
if Set(keys(nt)) == Set([, :w])
95+
return vcat(nt.μ, nt.w)
96+
elseif Set(keys(nt)) == Set([:z])
97+
return nt.z
98+
else
99+
error()
100+
end
101+
end
102+
103+
function unflatten(vec::AbstractVector, group::Tuple)
104+
if Set(group) == Set([, :w])
105+
return (; μ=vec[1:2], w=vec[3:4])
106+
elseif Set(group) == Set([:z])
107+
return (; z=vec)
108+
else
109+
error()
110+
end
111+
end
112+
113+
# sampler's states to internal representation
114+
# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?)
115+
116+
# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed
117+
function recompute_logprob!!(gmm::ConditionedGMM, vals, state)
118+
return setlogp!(state, _logdensity(gmm, vals))
119+
end
120+
73121
## test using Turing
74122

75123
# data generation

0 commit comments

Comments
 (0)