@@ -77,7 +77,8 @@ julia> unflatten([1.0,2.0,3.0,4.0,5.0,6.0], (x=(2,2), y=(2,)))
77
77
function unflatten (vec:: AbstractVector , variable_sizes:: NamedTuple )
78
78
result = Dict {Symbol,Array} ()
79
79
start_idx = 1
80
- for (name, size) in pairs (variable_sizes)
80
+ for name in keys (variable_sizes)
81
+ size = variable_sizes[name]
81
82
end_idx = start_idx + prod (size) - 1
82
83
result[name] = reshape (vec[start_idx: end_idx], size... )
83
84
start_idx = end_idx + 1
@@ -100,7 +101,7 @@ function update_trace(trace::NamedTuple, gibbs_state::GibbsState)
100
101
trace = merge (
101
102
trace,
102
103
unflatten (
103
- AbstractMCMC . get_params (sub_state),
104
+ vec (sub_state),
104
105
NamedTuple {(parameter_variable,)} ((
105
106
gibbs_state. variable_sizes[parameter_variable],
106
107
)),
@@ -197,9 +198,14 @@ function AbstractMCMC.step(
197
198
)
198
199
199
200
# recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems
200
- sub_state = AbstractMCMC. recompute_logprob!! (
201
- cond_logdensity, AbstractMCMC. get_params (sub_state), sub_state
202
- )
201
+ updated_log_prob = LogDensityProblems. logdensity (cond_logdensity, sub_state)
202
+
203
+ if ! hasproperty (sub_state, :logp )
204
+ error (
205
+ " $(typeof (sub_state)) does not have a `:logp` field, which is required by Gibbs sampling" ,
206
+ )
207
+ end
208
+ sub_state = BangBang. setproperty!! (sub_state, :logp , updated_log_prob)
203
209
204
210
sub_state = last (
205
211
AbstractMCMC. step (
0 commit comments