Skip to content

Commit 6132f0c

Browse files
committed
update code
1 parent 62a2332 commit 6132f0c

File tree

2 files changed

+29
-44
lines changed

2 files changed

+29
-44
lines changed

test/gibbs_example/gibbs.jl

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,28 @@
1+
using AbstractMCMC, AbstractPPL
2+
using BangBang.ConstructorBase: ConstructorBase
3+
14
"""
25
Gibbs(sampler_map::NamedTuple)
36
4-
An interface for block sampling in Markov Chain Monte Carlo (MCMC).
5-
6-
Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems.
7-
It allows different sampling methods to be applied to different parameters.
7+
A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
88
"""
9-
struct Gibbs{NT<:NamedTuple} <: AbstractMCMC.AbstractSampler
10-
sampler_map::NT
9+
struct Gibbs{T<:NamedTuple} <: AbstractMCMC.AbstractSampler
10+
sampler_map::T
1111
end
1212

1313
struct GibbsState{TraceNT<:NamedTuple,StateNT<:NamedTuple,SizeNT<:NamedTuple}
14-
"""
15-
Contains the values of all parameters up to the last iteration.
16-
"""
14+
"Contains the values of all parameters up to the last iteration."
1715
trace::TraceNT
1816

19-
"""
20-
Maps parameters to their sampler-specific MCMC states.
21-
"""
17+
"Maps parameters to their sampler-specific MCMC states."
2218
mcmc_states::StateNT
2319

24-
"""
25-
Maps parameters to their sizes.
26-
"""
20+
"Maps parameters to their sizes."
2721
variable_sizes::SizeNT
2822
end
2923

3024
struct GibbsTransition{ValuesNT<:NamedTuple}
31-
"""
32-
Realizations of the parameters, this is considered a "sample" in the MCMC chain.
33-
"""
25+
"Realizations of the parameters, this is considered a \"sample\" in the MCMC chain."
3426
values::ValuesNT
3527
end
3628

@@ -95,7 +87,7 @@ Update the trace with the values from the MCMC states of the sub-problems.
9587
function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
9688
for parameter_variable in keys(gibbs_state.mcmc_states)
9789
sub_state = gibbs_state.mcmc_states[parameter_variable]
98-
sub_state_params = vec(sub_state)
90+
sub_state_params = Base.vec(sub_state)
9991
unflattened_sub_state_params = unflatten(
10092
sub_state_params,
10193
NamedTuple{(parameter_variable,)}((
@@ -115,21 +107,19 @@ function AbstractMCMC.step(
115107
initial_params::NamedTuple,
116108
kwargs...,
117109
)
118-
if Set(keys(initial_params)) != Set(sampler.parameter_names)
110+
if Set(keys(initial_params)) != Set(keys(sampler.sampler_map))
119111
throw(
120112
ArgumentError(
121-
"initial_params must contain all parameters in the model, expected $(sampler.parameter_names), got $(keys(initial_params))",
113+
"initial_params must contain all parameters in the model, expected $(keys(sampler.sampler_map)), got $(keys(initial_params))",
122114
),
123115
)
124116
end
125117

126-
mcmc_states = Dict{Symbol,Any}()
127-
variable_sizes = Dict{Symbol,Tuple}()
128-
for parameter_variable in sampler.parameter_names
118+
mcmc_states, variable_sizes = map(keys(sampler.sampler_map)) do parameter_variable
129119
sub_sampler = sampler.sampler_map[parameter_variable]
130120

131121
variables_to_be_conditioned_on = setdiff(
132-
sampler.parameter_names, (parameter_variable,)
122+
keys(sampler.sampler_map), (parameter_variable,)
133123
)
134124
conditioning_variables_values = NamedTuple{Tuple(variables_to_be_conditioned_on)}(
135125
Tuple([initial_params[g] for g in variables_to_be_conditioned_on])
@@ -141,7 +131,6 @@ function AbstractMCMC.step(
141131
# LogDensityProblems' `logdensity` function expects a single vector of real numbers
142132
# `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values
143133
# and unflatten after the sampling step
144-
variable_sizes[parameter_variable] = Tuple(size(initial_params[parameter_variable]))
145134
flattened_sub_problem_parameters_values = flatten(sub_problem_parameters_values)
146135

147136
sub_state = last(
@@ -158,11 +147,13 @@ function AbstractMCMC.step(
158147
kwargs...,
159148
),
160149
)
161-
mcmc_states[parameter_variable] = sub_state
150+
(sub_state, Tuple(size(initial_params[parameter_variable])))
162151
end
163152

164153
gibbs_state = GibbsState(
165-
initial_params, NamedTuple(mcmc_states), NamedTuple(variable_sizes)
154+
initial_params,
155+
NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states),
156+
NamedTuple{Tuple(keys(sampler.sampler_map))}(variable_sizes),
166157
)
167158
trace = update_trace(NamedTuple(), gibbs_state)
168159
return GibbsTransition(trace), gibbs_state
@@ -176,14 +167,9 @@ function AbstractMCMC.step(
176167
args...;
177168
kwargs...,
178169
)
179-
trace = gibbs_state.trace
180-
mcmc_states = gibbs_state.mcmc_states
181-
variable_sizes = gibbs_state.variable_sizes
170+
(; trace, mcmc_states, variable_sizes) = gibbs_state
182171

183-
mcmc_states_dict = Dict(
184-
keys(mcmc_states) .=> [mcmc_states[k] for k in keys(mcmc_states)]
185-
)
186-
for parameter_variable in sampler.parameter_names
172+
mcmc_states = map(keys(sampler.sampler_map)) do parameter_variable
187173
sub_sampler = sampler.sampler_map[parameter_variable]
188174
sub_state = mcmc_states[parameter_variable]
189175
variables_to_be_conditioned_on = setdiff(
@@ -196,7 +182,8 @@ function AbstractMCMC.step(
196182
logdensity_model.logdensity, conditioning_variables_values
197183
)
198184

199-
_, sub_state = AbstractMCMC.logdensity_and_state(cond_logdensity, sub_state)
185+
logp = LogDensityProblems.logdensity_and_state(cond_logdensity, sub_state)
186+
sub_state = constructorof(typeof(sub_state))(; logp=logp)
200187
sub_state = last(
201188
AbstractMCMC.step(
202189
rng,
@@ -207,12 +194,10 @@ function AbstractMCMC.step(
207194
kwargs...,
208195
),
209196
)
210-
mcmc_states_dict[parameter_variable] = sub_state
211197
trace = update_trace(trace, gibbs_state)
198+
sub_state
212199
end
200+
mcmc_states = NamedTuple{Tuple(keys(sampler.sampler_map))}(mcmc_states)
213201

214-
mcmc_states = NamedTuple{Tuple(keys(mcmc_states_dict))}(
215-
Tuple([mcmc_states_dict[k] for k in keys(mcmc_states_dict)])
216-
)
217202
return GibbsTransition(trace), GibbsState(trace, mcmc_states, variable_sizes)
218203
end

test/gibbs_example/mh.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ function Base.vec(state::MHState)
2424
return state.params
2525
end
2626

27-
struct RandomWalkMH <: AbstractMCMC.AbstractSampler
28-
σ::Float64
27+
struct RandomWalkMH{T} <: AbstractMCMC.AbstractSampler
28+
σ::T
2929
end
3030

31-
struct IndependentMH <: AbstractMCMC.AbstractSampler
32-
proposal_dist::Distributions.Distribution
31+
struct IndependentMH{T} <: AbstractMCMC.AbstractSampler
32+
proposal_dist::T
3333
end
3434

3535
function AbstractMCMC.step(

0 commit comments

Comments
 (0)