Skip to content

Commit f05f293

Browse files
committed
unfinished gibbs example
1 parent 95d781b commit f05f293

File tree

1 file changed

+177
-0
lines changed

1 file changed

+177
-0
lines changed

src/example.jl

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
using LogDensityProblems, Distributions, LinearAlgebra, Random
2+
using OrderedCollections
3+
## Define a simple GMM problem
4+
5+
struct GMM{Tdata}
6+
data::NamedTuple
7+
end
8+
9+
struct ConditionedGMM{conditioned_vars}
10+
data::NamedTuple
11+
conditioned_values::NamedTuple{conditioned_vars}
12+
end
13+
14+
function log_joint(;μ, w, z, x)
15+
# μ is mean of each component
16+
# w is weights of each component
17+
# z is assignment of each data point
18+
# x is data
19+
20+
K = 2
21+
D = 2
22+
N = size(x, 1)
23+
logp = .0
24+
25+
μ_prior = MvNormal(zeros(K), I)
26+
logp += sum(logpdf(μ_prior, μ))
27+
28+
w_prior = Dirichlet(K, 1.0)
29+
logp += logpdf(w_prior, w)
30+
31+
z_prior = Categorical(w)
32+
logp += sum([logpdf(z_prior, z[i]) for i in 1:N])
33+
34+
for i in 1:N
35+
logp += logpdf(MvNormal(fill(μ[z[i]], D), I), x[i, :])
36+
end
37+
38+
return logp
39+
end
40+
41+
function condition(gmm::GMM, conditioned_values::NamedTuple)
42+
return ConditionedGMM(gmm.data, conditioned_values)
43+
end
44+
45+
function logdensity(gmm::ConditionedGMM{conditioned_vars}, params) where {conditioned_vars}
46+
if conditioned_vars == (, :w)
47+
return log_joint(;μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data)
48+
elseif conditioned_vars == (:z,)
49+
return log_joint(;μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data)
50+
else
51+
throw(ArgumentError("condition group not supported"))
52+
end
53+
end
54+
55+
function LogDensityProblems.logdensity(gmm::ConditionedGMM{conditioned_vars}, params_vec::AbstractVector) where {conditioned_vars}
56+
if conditioned_vars == (, :w)
57+
params = (; z= params_vec)
58+
elseif conditioned_vars == (:z,)
59+
params = (; μ= params_vec[1:2], w= params_vec[3:4])
60+
else
61+
throw(ArgumentError("condition group not supported"))
62+
end
63+
64+
return logdensity(gmm, params)
65+
end
66+
67+
function LogDensityProblems.dimension(gmm::ConditionedGMM{conditioned_vars}) where {conditioned_vars}
68+
if conditioned_vars == (, :w)
69+
return size(gmm.data.x, 1)
70+
elseif conditioned_vars == (:z,)
71+
return size(gmm.data.x, 1)
72+
else
73+
throw(ArgumentError("condition group not supported"))
74+
end
75+
end
76+
77+
struct Gibbs <: AbstractMCMC.AbstractSampler
78+
sampler_map::OrderedDict
79+
end
80+
81+
# ! initialize the params here
82+
struct GibbsState
83+
"contains all the values of the model parameters"
84+
values::NamedTuple
85+
states::OrderedDict
86+
end
87+
88+
struct GibbsTransition
89+
values::NamedTuple
90+
end
91+
92+
function AbstractMCMC.step(
93+
rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs...
94+
)
95+
states = OrderedDict()
96+
for group in collect(keys(sampler.sampler_map))
97+
sampler = sampler.sampler_map[group]
98+
cond_val = NamedTuple{group}([initial_params[g] for g in group]...)
99+
trans, state = AbstractMCMC.step(rng, condition(model, cond_val), sampler, args...; kwargs...)
100+
states[group] = state
101+
end
102+
return GibbsTransition(initial_params), GibbsState(initial_params, states)
103+
end
104+
105+
# questions is: when do we assume the logp from last iteration is not reliable anymore
106+
107+
function AbstractMCMC.step(
108+
rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs...
109+
)
110+
for group in collect(keys(sampler.sampler_map))
111+
sampler = sampler.sampler_map[group]
112+
state = state.states[group]
113+
trans, state = AbstractMCMC.step(rng, condition(model, state.values[group]), sampler, state, args...; kwargs...)
114+
# TODO: what values to condition on here? stored where?
115+
state.states[group] = state
116+
end
117+
return
118+
end
119+
120+
# importance sampling
121+
struct ImportanceSampling <: AbstractMCMC.AbstractSampler
122+
"number of samples"
123+
n::Int
124+
proposal
125+
end
126+
127+
struct ImportanceSamplingState
128+
129+
end
130+
131+
struct ImportanceSamplingTransition
132+
values
133+
end
134+
135+
# initial step
136+
function AbstractMCMC.step(
137+
rng::AbstractRNG, logdensity, sampler::ImportanceSampling, args...; kwargs...
138+
)
139+
140+
end
141+
142+
function IS_step(rng::AbstractRNG, logdensity, sampler::ImportanceSampling, state::ImportanceSamplingState, args...; kwargs...)
143+
proposals = rand(rng, sampler.proposal, sampler.n)
144+
weights = logdensity.(proposals) .- log.(logpdf.(sampler.proposal, proposals))
145+
sample = rand(rng, Categorical(softmax(weights)))
146+
return ImportanceSamplingTransition(proposals[sample]), ImportanceSamplingState()
147+
end
148+
149+
150+
function AbstractMCMC.step(
151+
rng::AbstractRNG, logdensity, sampler::ImportanceSampling, state::ImportanceSamplingState, args...; kwargs...
152+
)
153+
return
154+
end
155+
156+
struct RWMH <: AbstractMCMC.AbstractSampler
157+
proposal
158+
end
159+
160+
function AbstractMCMC.step(
161+
rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs...
162+
)
163+
proposal = rand(rng, sampler.proposal)
164+
165+
acceptance_probability = min(1, exp(logdensity(proposal) - logdensity(args[1])))
166+
if rand(rng) < acceptance_probability
167+
return proposal, nothing
168+
else
169+
return args[1], nothing
170+
end
171+
end
172+
173+
function AbstractMCMC.step(
174+
rng::AbstractRNG, logdensity, sampler::RWMH, state::RWMHState, args...; kwargs...
175+
)
176+
return
177+
end

0 commit comments

Comments
 (0)