Skip to content

Commit 9bc58c8

Browse files
committed
[no ci] Fix pMCMC
1 parent 7e522a6 commit 9bc58c8

File tree

2 files changed

+12
-22
lines changed

2 files changed

+12
-22
lines changed

src/mcmc/gibbs.jl

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -353,19 +353,9 @@ This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated h
353353
support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
354354
incompatible with step_warmup.
355355
"""
356-
function initial_varinfo(rng, model, spl, initial_params)
356+
function initial_varinfo(rng, model, spl, initial_params::DynamicPPL.AbstractInitStrategy)
357357
vi = DynamicPPL.default_varinfo(rng, model, spl)
358-
359-
# Update the parameters if provided.
360-
if initial_params !== nothing
361-
vi = DynamicPPL.initialize_parameters!!(vi, initial_params, model)
362-
363-
# Update joint log probability.
364-
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
365-
# and https://github.com/TuringLang/Turing.jl/issues/1563
366-
# to avoid that existing variables are resampled
367-
vi = last(DynamicPPL.evaluate!!(model, vi))
368-
end
358+
_, vi = DynamicPPL.init!!(rng, model, vi, initial_params)
369359
return vi
370360
end
371361

src/mcmc/particle_mcmc.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ function unset_all_del!(vi::AbstractVarInfo)
3636
return nothing
3737
end
3838

39-
# TODO(penelopeysm / DPPL 0.38): Figure this out
4039
struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
4140
rng::R
4241
end
42+
DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf()
4343

4444
struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
4545
model::M
@@ -75,8 +75,7 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
7575
# In such a case, we need to ensure that when we continue sampling (i.e.
7676
# the next time we hit tilde_assume!!), we don't use the values in the
7777
# reference particle but rather sample new values.
78-
trace = Accessors.@set trace.resample = true
79-
return trace
78+
return TracedModel(trace.model, trace.varinfo, trace.evaluator, true)
8079
end
8180

8281
function AdvancedPS.reset_model(trace::TracedModel)
@@ -309,8 +308,6 @@ function DynamicPPL.initialstep(
309308
kwargs...,
310309
)
311310
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
312-
# Reset the VarInfo before new sweep
313-
set_all_del!(vi)
314311

315312
# Create a new set of particles
316313
num_particles = spl.alg.nparticles
@@ -348,9 +345,6 @@ function AbstractMCMC.step(
348345
# Create reference particle for which the samples will be retained.
349346
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, vi, state.rng, false))
350347

351-
# For all other particles, do not retain the variables but resample them.
352-
set_all_del!(vi)
353-
354348
# Create a new set of particles.
355349
num_particles = spl.alg.nparticles
356350
x = map(1:num_particles) do i
@@ -410,7 +404,7 @@ function get_trace_local_resampled_maybe(fallback_resampled::Bool)
410404
catch e
411405
e == KeyError(:task_variable) ? nothing : rethrow(e)
412406
end
413-
return (trace === nothing ? fallback_resampled : trace.resample)::Bool
407+
return (trace === nothing ? fallback_resampled : trace.model.f.resample)::Bool
414408
end
415409

416410
"""
@@ -479,7 +473,13 @@ function DynamicPPL.tilde_assume!!(
479473
return x, vi
480474
end
481475

482-
function DynamicPPL.tilde_observe!!(::ParticleMCMCContext, right, left, vn, vi)
476+
function DynamicPPL.tilde_observe!!(
477+
::ParticleMCMCContext,
478+
right::Distribution,
479+
left,
480+
vn::Union{VarName,Nothing},
481+
vi::AbstractVarInfo,
482+
)
483483
arg_vi_id = objectid(vi)
484484
vi = get_trace_local_varinfo_maybe(vi)
485485
using_local_vi = objectid(vi) == arg_vi_id

0 commit comments

Comments
 (0)