1
- using AbstractMCMC, AbstractPPL
2
- using BangBang. ConstructorBase: ConstructorBase
1
+ using AbstractMCMC: AbstractMCMC
2
+ using AbstractPPL: AbstractPPL
3
+ using MCMCChains: Chains
4
+ using Random
3
5
4
6
"""
5
7
Gibbs(sampler_map::NamedTuple)
@@ -99,27 +101,34 @@ function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
99
101
return trace
100
102
end
101
103
104
+ function error_if_not_fully_initialized (
105
+ initial_params:: NamedTuple{ParamNames} , sampler:: Gibbs{<:NamedTuple{SamplerNames}}
106
+ ) where {ParamNames,SamplerNames}
107
+ if Set (ParamNames) != Set (SamplerNames)
108
+ throw (
109
+ ArgumentError (
110
+ " initial_params must contain all parameters in the model, expected $(SamplerNames) , got $(ParamNames) " ,
111
+ ),
112
+ )
113
+ end
114
+ end
115
+
102
116
function AbstractMCMC. step (
103
117
rng:: Random.AbstractRNG ,
104
118
logdensity_model:: AbstractMCMC.LogDensityModel ,
105
- sampler:: Gibbs ,
119
+ sampler:: Gibbs{Tsamplingmap} ,
106
120
args... ;
107
121
initial_params:: NamedTuple ,
108
122
kwargs... ,
109
- )
110
- if Set (keys (initial_params)) != Set (keys (sampler. sampler_map))
111
- throw (
112
- ArgumentError (
113
- " initial_params must contain all parameters in the model, expected $(keys (sampler. sampler_map)) , got $(keys (initial_params)) " ,
114
- ),
115
- )
116
- end
123
+ ) where {Tsamplingmap}
124
+ error_if_not_fully_initialized (initial_params, sampler)
117
125
118
- mcmc_states, variable_sizes = map (keys (sampler. sampler_map)) do parameter_variable
126
+ model_parameter_names = fieldnames (Tsamplingmap)
127
+ results = map (model_parameter_names) do parameter_variable
119
128
sub_sampler = sampler. sampler_map[parameter_variable]
120
129
121
130
variables_to_be_conditioned_on = setdiff (
122
- keys (sampler . sampler_map) , (parameter_variable,)
131
+ model_parameter_names , (parameter_variable,)
123
132
)
124
133
conditioning_variables_values = NamedTuple {Tuple(variables_to_be_conditioned_on)} (
125
134
Tuple ([initial_params[g] for g in variables_to_be_conditioned_on])
@@ -137,7 +146,7 @@ function AbstractMCMC.step(
137
146
AbstractMCMC. step (
138
147
rng,
139
148
AbstractMCMC. LogDensityModel (
140
- AbstractMCMC . condition (
149
+ AbstractPPL . condition (
141
150
logdensity_model. logdensity, conditioning_variables_values
142
151
),
143
152
),
@@ -150,40 +159,46 @@ function AbstractMCMC.step(
150
159
(sub_state, Tuple (size (initial_params[parameter_variable])))
151
160
end
152
161
162
+ mcmc_states = first .(results)
163
+ variable_sizes = last .(results)
164
+
153
165
gibbs_state = GibbsState (
154
166
initial_params,
155
- NamedTuple {Tuple(keys(sampler.sampler_map) )} (mcmc_states),
156
- NamedTuple {Tuple(keys(sampler.sampler_map) )} (variable_sizes),
167
+ NamedTuple {Tuple(model_parameter_names )} (mcmc_states),
168
+ NamedTuple {Tuple(model_parameter_names )} (variable_sizes),
157
169
)
170
+
158
171
trace = update_trace (NamedTuple (), gibbs_state)
159
172
return GibbsTransition (trace), gibbs_state
160
173
end
161
174
175
+ # subsequent steps
162
176
function AbstractMCMC. step (
163
177
rng:: Random.AbstractRNG ,
164
178
logdensity_model:: AbstractMCMC.LogDensityModel ,
165
- sampler:: Gibbs ,
179
+ sampler:: Gibbs{Tsamplingmap} ,
166
180
gibbs_state:: GibbsState ,
167
181
args... ;
168
182
kwargs... ,
169
- )
183
+ ) where {Tsamplingmap}
170
184
(; trace, mcmc_states, variable_sizes) = gibbs_state
171
185
172
- mcmc_states = map (keys (sampler. sampler_map)) do parameter_variable
186
+ model_parameter_names = fieldnames (Tsamplingmap)
187
+ mcmc_states = map (model_parameter_names) do parameter_variable
173
188
sub_sampler = sampler. sampler_map[parameter_variable]
174
189
sub_state = mcmc_states[parameter_variable]
175
190
variables_to_be_conditioned_on = setdiff (
176
- sampler . parameter_names , (parameter_variable,)
191
+ model_parameter_names , (parameter_variable,)
177
192
)
178
193
conditioning_variables_values = NamedTuple {Tuple(variables_to_be_conditioned_on)} (
179
194
Tuple ([trace[g] for g in variables_to_be_conditioned_on])
180
195
)
181
- cond_logdensity = AbstractMCMC . condition (
196
+ cond_logdensity = AbstractPPL . condition (
182
197
logdensity_model. logdensity, conditioning_variables_values
183
198
)
184
199
185
- logp = LogDensityProblems. logdensity_and_state (cond_logdensity, sub_state)
186
- sub_state = constructorof ( typeof ( sub_state))(; logp = logp)
200
+ logp = LogDensityProblems. logdensity (cond_logdensity, sub_state)
201
+ sub_state = ( sub_state)( logp)
187
202
sub_state = last (
188
203
AbstractMCMC. step (
189
204
rng,
@@ -197,7 +212,7 @@ function AbstractMCMC.step(
197
212
trace = update_trace (trace, gibbs_state)
198
213
sub_state
199
214
end
200
- mcmc_states = NamedTuple {Tuple(keys(sampler.sampler_map) )} (mcmc_states)
215
+ mcmc_states = NamedTuple {Tuple(model_parameter_names )} (mcmc_states)
201
216
202
217
return GibbsTransition (trace), GibbsState (trace, mcmc_states, variable_sizes)
203
218
end
0 commit comments