Skip to content

Commit 283d4dd

Browse files
committed
more fix fix fix
1 parent 64ebd92 commit 283d4dd

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

src/mcmc/is.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,20 @@ DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler
3131
function DynamicPPL.initialstep(
3232
rng::AbstractRNG, model::Model, spl::Sampler{<:IS}, vi::AbstractVarInfo; kwargs...
3333
)
34-
return Transition(model, vi), nothing
34+
# Need to manually construct the Transition here because we only
35+
# want to use the likelihood.
36+
xs = Turing.Inference.getparams(model, vi)
37+
lp = DynamicPPL.getloglikelihood(vi)
38+
return Transition(xs, lp, nothing), nothing
3539
end
3640

3741
function AbstractMCMC.step(
3842
rng::Random.AbstractRNG, model::Model, spl::Sampler{<:IS}, ::Nothing; kwargs...
3943
)
4044
vi = VarInfo(rng, model, spl)
41-
return Transition(model, vi), nothing
45+
xs = Turing.Inference.getparams(model, vi)
46+
lp = DynamicPPL.getloglikelihood(vi)
47+
return Transition(xs, lp, nothing), nothing
4248
end
4349

4450
# Calculate evidence.
@@ -53,5 +59,6 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
5359
r = rand(rng, dist)
5460
vi = push!!(vi, vn, r, dist)
5561
end
56-
return r, 0, vi
62+
vi = accumulate_assume!!(vi, r, 0.0, vn, dist)
63+
return r, vi
5764
end

src/mcmc/particle_mcmc.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ function AdvancedPS.advance!(
3535
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
3636
)
3737
# Make sure we load/reset the rng in the new replaying mechanism
38-
trace.model.f.varinfo = DynamicPPL.increment_num_produce!!(trace.model.f.varinfo)
38+
trace = Accessors.@set trace.model.f.varinfo = DynamicPPL.increment_num_produce!!(
39+
trace.model.f.varinfo
40+
)
3941
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
4042
score = consume(trace.model.ctask)
4143
if score === nothing
@@ -51,13 +53,11 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
5153
end
5254

5355
function AdvancedPS.reset_model(trace::TracedModel)
54-
DynamicPPL.reset_num_produce!(trace.varinfo)
55-
return trace
56+
return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo)
5657
end
5758

5859
function AdvancedPS.reset_logprob!(trace::TracedModel)
59-
DynamicPPL.resetlogp!!(trace.model.varinfo)
60-
return trace
60+
return Accessors.@set trace.model.varinfo = DynamicPPL.resetlogp!!(trace.model.varinfo)
6161
end
6262

6363
function AdvancedPS.update_rng!(

0 commit comments

Comments
 (0)