Skip to content

Commit 67ff8e8

Browse files
committed
results is wrong
1 parent 55dbab5 commit 67ff8e8

File tree

2 files changed

+28
-69
lines changed

2 files changed

+28
-69
lines changed

gibbs_example/gibbs.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ using OrderedCollections
44

55
##
66

7-
# TODO: introduce some kind of parameter format, for instance, a flattened vector
8-
# then define some kind of function to transform the flattened vector into model's representation
9-
107
struct Gibbs <: AbstractMCMC.AbstractSampler
118
sampler_map::OrderedDict
129
end
@@ -73,12 +70,12 @@ function AbstractMCMC.step(
7370
cond_val = NamedTuple{Tuple(group_complement)}(
7471
Tuple([vi[g] for g in group_complement])
7572
)
73+
cond_logdensity = condition(logdensity_model.logdensity, cond_val)
74+
sub_state = recompute_logprob!!(cond_logdensity, getparams(sub_state), sub_state)
7675
sub_state = last(
7776
AbstractMCMC.step(
7877
rng,
79-
AbstractMCMC.LogDensityModel(
80-
condition(logdensity_model.logdensity, cond_val)
81-
),
78+
AbstractMCMC.LogDensityModel(cond_logdensity),
8279
sub_spl,
8380
sub_state,
8481
args...;
@@ -87,8 +84,8 @@ function AbstractMCMC.step(
8784
)
8885
state.states[group] = sub_state
8986
end
90-
for sub_state in values(state.states)
91-
vi = merge(vi, getparams(sub_state))
87+
for (group, sub_state) in state.states
88+
vi = merge(vi, unflatten(getparams(sub_state), group))
9289
end
9390
return GibbsTransition(vi), GibbsState(vi, state.states)
9491
end
@@ -103,9 +100,16 @@ samples = sample(
103100
OrderedDict(
104101
(:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
105102
(:w,) => PriorMH(Dirichlet(2, 1.0)),
106-
(, :w) => RWMH(1),
103+
(,) => RWMH(1),
107104
),
108105
),
109-
10000;
106+
100000;
110107
initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]),
111-
)
108+
);
109+
110+
z_samples = [sample.values.z for sample in samples][20001:end]
111+
μ_samples = [sample.values.μ for sample in samples][20001:end]
112+
w_samples = [sample.values.w for sample in samples][20001:end]
113+
114+
mean(μ_samples)
115+
mean(w_samples)

gibbs_example/gmm.jl

Lines changed: 13 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -44,42 +44,16 @@ function condition(gmm::GMM, conditioned_values::NamedTuple)
4444
return ConditionedGMM(gmm.data, conditioned_values)
4545
end
4646

47-
function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params)
48-
return log_joint(;
49-
μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x
50-
)
51-
end
52-
53-
function _logdensity(gmm::ConditionedGMM{(:z,)}, params)
54-
return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x)
55-
end
56-
57-
function LogDensityProblems.logdensity(
58-
gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}},
59-
params_vec::AbstractVector,
60-
)
61-
@assert length(params_vec) == 60
62-
return _logdensity(gmm, (; z=params_vec))
63-
end
64-
function LogDensityProblems.logdensity(
65-
gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector
66-
)
67-
@assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))"
68-
return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4]))
69-
end
70-
71-
function LogDensityProblems.dimension(gmm::GMM)
72-
return 4 + size(gmm.data.x, 1)
73-
end
74-
75-
function LogDensityProblems.dimension(
76-
gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}
77-
)
78-
return 4
79-
end
80-
81-
function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)})
82-
return size(gmm.data.x, 1)
47+
function LogDensityProblems.logdensity(gmm::ConditionedGMM{names}, params::AbstractVector) where {names}
48+
if Set(names) == Set([, :w]) # conditioned on μ, w, so params are z
49+
return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x)
50+
elseif Set(names) == Set([:z, :w]) # conditioned on z, w, so params are μ
51+
return log_joint(; μ=params, w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x)
52+
elseif Set(names) == Set([:z, ]) # conditioned on z, μ, so params are w
53+
return log_joint(; μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x)
54+
else
55+
error("Unsupported conditioning configuration.")
56+
end
8357
end
8458

8559
function LogDensityProblems.capabilities(::GMM)
@@ -91,41 +65,22 @@ function LogDensityProblems.capabilities(::ConditionedGMM)
9165
end
9266

9367
function flatten(nt::NamedTuple)
94-
if Set(keys(nt)) == Set([, :w])
95-
return vcat(nt.μ, nt.w)
96-
elseif Set(keys(nt)) == Set([:z])
97-
return nt.z
98-
else
99-
error()
100-
end
68+
return only(values(nt))
10169
end
10270

10371
function unflatten(vec::AbstractVector, group::Tuple)
104-
if Set(group) == Set([, :w])
105-
return (; μ=vec[1:2], w=vec[3:4])
106-
elseif Set(group) == Set([:z])
107-
return (; z=vec)
108-
else
109-
error()
110-
end
72+
return NamedTuple((only(group) => vec,))
11173
end
11274

113-
# sampler's states to internal representation
114-
# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?)
115-
116-
# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed
11775
function recompute_logprob!!(gmm::ConditionedGMM, vals, state)
118-
return setlogp!(state, _logdensity(gmm, vals))
76+
return setlogp!!(state, LogDensityProblems.logdensity(gmm, vals))
11977
end
12078

12179
## test using Turing
12280

12381
# data generation
12482

125-
using Distributions
12683
using FillArrays
127-
using LinearAlgebra
128-
using Random
12984

13085
w = [0.5, 0.5]
13186
μ = [-3.5, 0.5]

0 commit comments

Comments
 (0)