@@ -135,23 +135,6 @@ function SMC(threshold::Real)
135
135
return SMC (AdvancedPS. resample_systematic, threshold)
136
136
end
137
137
138
- struct SMCTransition{T,F<: AbstractFloat } <: AbstractTransition
139
- " The parameters for any given sample."
140
- θ:: T
141
- " The joint log probability of the sample (NOTE: does not work, always set to zero)."
142
- lp:: F
143
- " The weight of the particle the sample was retrieved from."
144
- weight:: F
145
- end
146
-
147
- function SMCTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , weight)
148
- theta = getparams (model, vi)
149
- lp = DynamicPPL. getlogjoint_internal (vi)
150
- return SMCTransition (theta, lp, weight)
151
- end
152
-
153
- getstats_with_lp (t:: SMCTransition ) = (lp= t. lp, weight= t. weight)
154
-
155
138
struct SMCState{P,F<: AbstractFloat }
156
139
particles:: P
157
140
particleindex:: Int
@@ -228,7 +211,8 @@ function DynamicPPL.initialstep(
228
211
weight = AdvancedPS. getweight (particles, 1 )
229
212
230
213
# Compute the first transition and the first state.
231
- transition = SMCTransition (model, particle. model. f. varinfo, weight)
214
+ stats = (; weight= weight, logevidence= logevidence)
215
+ transition = Transition (model, particle. model. f. varinfo, stats)
232
216
state = SMCState (particles, 2 , logevidence)
233
217
234
218
return transition, state
@@ -246,7 +230,8 @@ function AbstractMCMC.step(
246
230
weight = AdvancedPS. getweight (particles, index)
247
231
248
232
# Compute the transition and the next state.
249
- transition = SMCTransition (model, particle. model. f. varinfo, weight)
233
+ stats = (; weight= weight, logevidence= state. average_logevidence)
234
+ transition = Transition (model, particle. model. f. varinfo, stats)
250
235
nextstate = SMCState (state. particles, index + 1 , state. average_logevidence)
251
236
252
237
return transition, nextstate
@@ -300,32 +285,28 @@ Equivalent to [`PG`](@ref).
300
285
"""
301
286
const CSMC = PG # type alias of PG as Conditional SMC
302
287
303
- struct PGTransition{T,F<: AbstractFloat } <: AbstractTransition
304
- " The parameters for any given sample."
305
- θ:: T
306
- " The joint log probability of the sample (NOTE: does not work, always set to zero)."
307
- lp:: F
308
- " The log evidence of the sample."
309
- logevidence:: F
310
- end
311
-
312
288
struct PGState
313
289
vi:: AbstractVarInfo
314
290
rng:: Random.AbstractRNG
315
291
end
316
292
317
293
get_varinfo (state:: PGState ) = state. vi
318
294
319
- function PGTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , logevidence)
320
- theta = getparams (model, vi)
321
- lp = DynamicPPL. getlogjoint_internal (vi)
322
- return PGTransition (theta, lp, logevidence)
323
- end
324
-
325
- getstats_with_lp (t:: PGTransition ) = (lp= t. lp, logevidence= t. logevidence)
326
-
327
- function getlogevidence (samples, sampler:: Sampler{<:PG} , state:: PGState )
328
- return mean (x. logevidence for x in samples)
295
+ function getlogevidence (
296
+ transitions:: AbstractVector{<:Turing.Inference.Transition} ,
297
+ sampler:: Sampler{<:PG} ,
298
+ state:: PGState ,
299
+ )
300
+ logevidences = map (transitions) do t
301
+ if haskey (t. stat, :logevidence )
302
+ return t. stat. logevidence
303
+ else
304
+ # This should not really happen, but if it does we can handle it
305
+ # gracefully
306
+ return missing
307
+ end
308
+ end
309
+ return mean (logevidences)
329
310
end
330
311
331
312
function DynamicPPL. initialstep (
@@ -357,7 +338,7 @@ function DynamicPPL.initialstep(
357
338
358
339
# Compute the first transition.
359
340
_vi = reference. model. f. varinfo
360
- transition = PGTransition (model, _vi, logevidence)
341
+ transition = Transition (model, _vi, (; logevidence= logevidence) )
361
342
362
343
return transition, PGState (_vi, reference. rng)
363
344
end
@@ -397,7 +378,7 @@ function AbstractMCMC.step(
397
378
398
379
# Compute the transition.
399
380
_vi = newreference. model. f. varinfo
400
- transition = PGTransition (model, _vi, logevidence)
381
+ transition = Transition (model, _vi, (; logevidence= logevidence) )
401
382
402
383
return transition, PGState (_vi, newreference. rng)
403
384
end
0 commit comments