Skip to content

Commit 590d37f

Browse files
committed
some updates
1 parent f05f293 commit 590d37f

File tree

6 files changed

+270
-180
lines changed

6 files changed

+270
-180
lines changed

gibbs_example/Project.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[deps]
2+
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
3+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
5+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6+
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
7+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

gibbs_example/gibbs.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
using LogDensityProblems, Distributions, LinearAlgebra, Random
2+
using OrderedCollections
3+
4+
struct Gibbs <: AbstractMCMC.AbstractSampler
5+
sampler_map::OrderedDict
6+
end
7+
8+
struct GibbsState
9+
values::NamedTuple
10+
states::OrderedDict
11+
end
12+
13+
struct GibbsTransition
14+
values::NamedTuple
15+
end
16+
17+
function AbstractMCMC.step(
18+
rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs...
19+
)
20+
states = OrderedDict()
21+
for group in keys(sampler.sampler_map)
22+
sampler = sampler.sampler_map[group]
23+
cond_val = NamedTuple{group}([initial_params[g] for g in group]...)
24+
trans, state = AbstractMCMC.step(
25+
rng, condition(model, cond_val), sampler, args...; kwargs...
26+
)
27+
states[group] = state
28+
end
29+
return GibbsTransition(initial_params), GibbsState(initial_params, states)
30+
end
31+
32+
function AbstractMCMC.step(
33+
rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs...
34+
)
35+
for group in collect(keys(sampler.sampler_map))
36+
sampler = sampler.sampler_map[group]
37+
state = state.states[group]
38+
trans, state = AbstractMCMC.step(
39+
rng, condition(model, state.values[group]), sampler, state, args...; kwargs...
40+
)
41+
# TODO: what values to condition on here? stored where?
42+
state.states[group] = state
43+
end
44+
return nothing
45+
end

gibbs_example/gmm.jl

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
using LogDensityProblems
2+
3+
abstract type AbstractGMM end
4+
5+
struct GMM <: AbstractGMM
6+
data::NamedTuple
7+
end
8+
9+
struct ConditionedGMM{conditioned_vars} <: AbstractGMM
10+
data::NamedTuple
11+
conditioned_values::NamedTuple{conditioned_vars}
12+
end
13+
14+
function log_joint(; μ, w, z, x)
15+
# μ is mean of each component
16+
# w is weights of each component
17+
# z is assignment of each data point
18+
# x is data
19+
20+
K = 2 # assume we know the number of components
21+
D = 2 # dimension of each data point
22+
N = size(x, 2) # number of data points
23+
logp = 0.0
24+
25+
μ_prior = MvNormal(zeros(K), I)
26+
logp += logpdf(μ_prior, μ)
27+
28+
w_prior = Dirichlet(K, 1.0)
29+
logp += logpdf(w_prior, w)
30+
31+
z_prior = Categorical(w)
32+
logp += sum([logpdf(z_prior, z[i]) for i in 1:N])
33+
34+
obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ]
35+
for i in 1:N
36+
logp += logpdf(obs_priors[z[i]], x[:, i])
37+
end
38+
39+
return logp
40+
end
41+
42+
function condition(gmm::GMM, conditioned_values::NamedTuple)
43+
return ConditionedGMM(gmm.data, conditioned_values)
44+
end
45+
46+
function _logdensity(gmm::ConditionedGMM{(:μ, :w)}, params)
47+
return log_joint(;
48+
μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x
49+
)
50+
end
51+
function _logdensity(gmm::ConditionedGMM{(:z,)}, params)
52+
return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x)
53+
end
54+
55+
function LogDensityProblems.logdensity(
56+
gmm::ConditionedGMM{(:μ, :w)}, params_vec::AbstractVector
57+
)
58+
return _logdensity(gmm, (; z=params_vec))
59+
end
60+
function LogDensityProblems.logdensity(
61+
gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector
62+
)
63+
return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4]))
64+
end
65+
66+
function LogDensityProblems.dimension(gmm::ConditionedGMM{(:μ, :w)})
67+
return size(gmm.data.x, 1)
68+
end
69+
function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)})
70+
return size(gmm.data.x, 1)
71+
end
72+
73+
## test using Turing
74+
75+
# data generation
76+
77+
using Distributions
78+
using FillArrays
79+
using LinearAlgebra
80+
using Random
81+
82+
w = [0.5, 0.5]
83+
μ = [-3.5, 0.5]
84+
mixturemodel = Distributions.MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ], w)
85+
86+
N = 60
87+
x = rand(mixturemodel, N);
88+
89+
# Turing model from https://turinglang.org/docs/tutorials/01-gaussian-mixture-model/
90+
using Turing
91+
92+
@model function gaussian_mixture_model(x)
93+
# Draw the parameters for each of the K=2 clusters from a standard normal distribution.
94+
K = 2
95+
μ ~ MvNormal(Zeros(K), I)
96+
97+
# Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
98+
w ~ Dirichlet(K, 1.0)
99+
# Alternatively, one could use a fixed set of weights.
100+
# w = fill(1/K, K)
101+
102+
# Construct categorical distribution of assignments.
103+
distribution_assignments = Categorical(w)
104+
105+
# Construct multivariate normal distributions of each cluster.
106+
D, N = size(x)
107+
distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
108+
109+
# Draw assignments for each datum and generate it from the multivariate normal distribution.
110+
k = Vector{Int}(undef, N)
111+
for i in 1:N
112+
k[i] ~ distribution_assignments
113+
x[:, i] ~ distribution_clusters[k[i]]
114+
end
115+
116+
return μ, w, k, __varinfo__
117+
end
118+
119+
model = gaussian_mixture_model(x);
120+
121+
using Test
122+
# full model
123+
μ, w, k, vi = model()
124+
@test log_joint(; μ=μ, w=w, z=k, x=x) DynamicPPL.getlogp(vi)
125+
126+
gmm = GMM((; x=x))
127+
128+
# cond model on μ, w
129+
μ, w, k, vi = (DynamicPPL.condition(model, (μ=μ, w=w)))()
130+
@test _logdensity(condition(gmm, (; μ=μ, w=w)), (; z=k)) DynamicPPL.getlogp(vi)
131+
132+
# cond model on z
133+
μ, w, k, vi = (DynamicPPL.condition(model, (z = k)))()
134+
@test _logdensity(condition(gmm, (; z=k)), (; μ=μ, w=w)) DynamicPPL.getlogp(vi)

gibbs_example/mh.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
struct RWMH <: AbstractMCMC.AbstractSampler
2+
σ
3+
end
4+
5+
struct MHTransition{T} where {T}
6+
params::T
7+
end
8+
9+
struct MHState{T} where {T}
10+
params::T
11+
logp::Float64
12+
end
13+
14+
getparams(state::MHState) = state.params
15+
setparams!!(state::MHState, params) = MHState(params, state.logp)
16+
getlogp(state::MHState) = state.logp
17+
setlogp!!(state::MHState, logp) = MHState(state.params, logp)
18+
19+
function AbstractMCMC.step(rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs...)
20+
params = rand(rng, LogDensityProblems.dimension(logdensity))
21+
return MHTransition(params),
22+
MHState(params, LogDensityProblems.logdensity(logdensity, params))
23+
end
24+
25+
function AbstractMCMC.step(
26+
rng::AbstractRNG, logdensity, sampler::RWMH, state::MHState, args...; kwargs...
27+
)
28+
params = getparams(state)
29+
proposal_dist = MvNormal(params, sampler.σ)
30+
proposal = rand(rng, proposal_dist)
31+
logp_proposal = logpdf(proposal_dist, proposal)
32+
accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state)))
33+
if accepted
34+
return MHTransition(proposal), MHState(proposal, logp_proposal)
35+
else
36+
return MHTransition(params), MHState(params, getlogp(state))
37+
end
38+
end
39+
40+
struct PriorMH <: AbstractMCMC.AbstractSampler
41+
prior_dist
42+
end
43+
44+
function AbstractMCMC.step(
45+
rng::AbstractRNG, logdensity, sampler::PriorMH, args...; kwargs...
46+
)
47+
params = rand(rng, sampler.prior_dist)
48+
return MHTransition(params), MHState(params, logdensity(params))
49+
end
50+
51+
function AbstractMCMC.step(
52+
rng::AbstractRNG, logdensity, sampler::PriorMH, state::MHState, args...; kwargs...
53+
)
54+
params = getparams(state)
55+
proposal_dist = sampler.prior_dist
56+
proposal = rand(rng, proposal_dist)
57+
logp_proposal = logpdf(proposal_dist, proposal)
58+
accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state)))
59+
if accepted
60+
return MHTransition(proposal), MHState(proposal, logp_proposal)
61+
else
62+
return MHTransition(params), MHState(params, getlogp(state))
63+
end
64+
end

src/AbstractMCMC.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,18 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr
8484
struct MCMCSerial <: AbstractMCMCEnsemble end
8585

8686
"""
87-
recompute_logprob!!(rng, model, sampler, state)
87+
get_logprob(state)
8888
89-
Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
89+
Returns the log-probability of the last sampling step, stored in `state`.
9090
"""
91-
function recompute_logprob!!(rng, model, sampler, state) end
91+
function get_logprob(state) end
92+
93+
"""
94+
set_logprob!(state, logprob)
95+
96+
Set the log-probability of the last sampling step, stored in `state`.
97+
"""
98+
function set_logprob!!(state, logprob) end
9299

93100
"""
94101
getparams(state)
@@ -97,6 +104,13 @@ Returns the values of the parameters in the state.
97104
"""
98105
function getparams(state) end
99106

107+
"""
108+
setparams!(state, params)
109+
110+
Set the values of the parameters in the state.
111+
"""
112+
function setparams!!(state, params) end
113+
100114
include("samplingstats.jl")
101115
include("logging.jl")
102116
include("interface.jl")

0 commit comments

Comments
 (0)