1
+ using AbstractMCMC, AbstractPPL
2
+ using BangBang. ConstructorBase: ConstructorBase
3
+
1
4
"""
2
5
Gibbs(sampler_map::NamedTuple)
3
6
4
- An interface for block sampling in Markov Chain Monte Carlo (MCMC).
5
-
6
- Gibbs sampling is a technique for dividing complex multivariate problems into simpler subproblems.
7
- It allows different sampling methods to be applied to different parameters.
7
+ A Gibbs sampler that allows for block sampling using different inference algorithms for each parameter.
8
8
"""
9
- struct Gibbs{NT <: NamedTuple } <: AbstractMCMC.AbstractSampler
10
- sampler_map:: NT
9
+ struct Gibbs{T <: NamedTuple } <: AbstractMCMC.AbstractSampler
10
+ sampler_map:: T
11
11
end
12
12
13
13
struct GibbsState{TraceNT<: NamedTuple ,StateNT<: NamedTuple ,SizeNT<: NamedTuple }
14
- """
15
- Contains the values of all parameters up to the last iteration.
16
- """
14
+ " Contains the values of all parameters up to the last iteration."
17
15
trace:: TraceNT
18
16
19
- """
20
- Maps parameters to their sampler-specific MCMC states.
21
- """
17
+ " Maps parameters to their sampler-specific MCMC states."
22
18
mcmc_states:: StateNT
23
19
24
- """
25
- Maps parameters to their sizes.
26
- """
20
+ " Maps parameters to their sizes."
27
21
variable_sizes:: SizeNT
28
22
end
29
23
30
24
struct GibbsTransition{ValuesNT<: NamedTuple }
31
- """
32
- Realizations of the parameters, this is considered a "sample" in the MCMC chain.
33
- """
25
+ " Realizations of the parameters, this is considered a \" sample\" in the MCMC chain."
34
26
values:: ValuesNT
35
27
end
36
28
@@ -95,7 +87,7 @@ Update the trace with the values from the MCMC states of the sub-problems.
95
87
function update_trace (trace:: NamedTuple , gibbs_state:: GibbsState )
96
88
for parameter_variable in keys (gibbs_state. mcmc_states)
97
89
sub_state = gibbs_state. mcmc_states[parameter_variable]
98
- sub_state_params = vec (sub_state)
90
+ sub_state_params = Base . vec (sub_state)
99
91
unflattened_sub_state_params = unflatten (
100
92
sub_state_params,
101
93
NamedTuple {(parameter_variable,)} ((
@@ -115,21 +107,19 @@ function AbstractMCMC.step(
115
107
initial_params:: NamedTuple ,
116
108
kwargs... ,
117
109
)
118
- if Set (keys (initial_params)) != Set (sampler. parameter_names )
110
+ if Set (keys (initial_params)) != Set (keys ( sampler. sampler_map) )
119
111
throw (
120
112
ArgumentError (
121
- " initial_params must contain all parameters in the model, expected $(sampler. parameter_names ) , got $(keys (initial_params)) " ,
113
+ " initial_params must contain all parameters in the model, expected $(keys ( sampler. sampler_map) ) , got $(keys (initial_params)) " ,
122
114
),
123
115
)
124
116
end
125
117
126
- mcmc_states = Dict {Symbol,Any} ()
127
- variable_sizes = Dict {Symbol,Tuple} ()
128
- for parameter_variable in sampler. parameter_names
118
+ mcmc_states, variable_sizes = map (keys (sampler. sampler_map)) do parameter_variable
129
119
sub_sampler = sampler. sampler_map[parameter_variable]
130
120
131
121
variables_to_be_conditioned_on = setdiff (
132
- sampler. parameter_names , (parameter_variable,)
122
+ keys ( sampler. sampler_map) , (parameter_variable,)
133
123
)
134
124
conditioning_variables_values = NamedTuple {Tuple(variables_to_be_conditioned_on)} (
135
125
Tuple ([initial_params[g] for g in variables_to_be_conditioned_on])
@@ -141,7 +131,6 @@ function AbstractMCMC.step(
141
131
# LogDensityProblems' `logdensity` function expects a single vector of real numbers
142
132
# `Gibbs` stores the parameters as a named tuple, thus we need to flatten the sub_problem_parameters_values
143
133
# and unflatten after the sampling step
144
- variable_sizes[parameter_variable] = Tuple (size (initial_params[parameter_variable]))
145
134
flattened_sub_problem_parameters_values = flatten (sub_problem_parameters_values)
146
135
147
136
sub_state = last (
@@ -158,11 +147,13 @@ function AbstractMCMC.step(
158
147
kwargs... ,
159
148
),
160
149
)
161
- mcmc_states [parameter_variable] = sub_state
150
+ (sub_state, Tuple ( size (initial_params [parameter_variable])))
162
151
end
163
152
164
153
gibbs_state = GibbsState (
165
- initial_params, NamedTuple (mcmc_states), NamedTuple (variable_sizes)
154
+ initial_params,
155
+ NamedTuple {Tuple(keys(sampler.sampler_map))} (mcmc_states),
156
+ NamedTuple {Tuple(keys(sampler.sampler_map))} (variable_sizes),
166
157
)
167
158
trace = update_trace (NamedTuple (), gibbs_state)
168
159
return GibbsTransition (trace), gibbs_state
@@ -176,14 +167,9 @@ function AbstractMCMC.step(
176
167
args... ;
177
168
kwargs... ,
178
169
)
179
- trace = gibbs_state. trace
180
- mcmc_states = gibbs_state. mcmc_states
181
- variable_sizes = gibbs_state. variable_sizes
170
+ (; trace, mcmc_states, variable_sizes) = gibbs_state
182
171
183
- mcmc_states_dict = Dict (
184
- keys (mcmc_states) .=> [mcmc_states[k] for k in keys (mcmc_states)]
185
- )
186
- for parameter_variable in sampler. parameter_names
172
+ mcmc_states = map (keys (sampler. sampler_map)) do parameter_variable
187
173
sub_sampler = sampler. sampler_map[parameter_variable]
188
174
sub_state = mcmc_states[parameter_variable]
189
175
variables_to_be_conditioned_on = setdiff (
@@ -196,7 +182,8 @@ function AbstractMCMC.step(
196
182
logdensity_model. logdensity, conditioning_variables_values
197
183
)
198
184
199
- _, sub_state = AbstractMCMC. logdensity_and_state (cond_logdensity, sub_state)
185
+ logp = LogDensityProblems. logdensity_and_state (cond_logdensity, sub_state)
186
+ sub_state = constructorof (typeof (sub_state))(; logp= logp)
200
187
sub_state = last (
201
188
AbstractMCMC. step (
202
189
rng,
@@ -207,12 +194,10 @@ function AbstractMCMC.step(
207
194
kwargs... ,
208
195
),
209
196
)
210
- mcmc_states_dict[parameter_variable] = sub_state
211
197
trace = update_trace (trace, gibbs_state)
198
+ sub_state
212
199
end
200
+ mcmc_states = NamedTuple {Tuple(keys(sampler.sampler_map))} (mcmc_states)
213
201
214
- mcmc_states = NamedTuple {Tuple(keys(mcmc_states_dict))} (
215
- Tuple ([mcmc_states_dict[k] for k in keys (mcmc_states_dict)])
216
- )
217
202
return GibbsTransition (trace), GibbsState (trace, mcmc_states, variable_sizes)
218
203
end
0 commit comments