Skip to content

Commit a1837b5

Browse files
committed
Fix Prior(), fix a couple more imports
1 parent 3d44c12 commit a1837b5

File tree

3 files changed

+15
-9
lines changed

3 files changed

+15
-9
lines changed

src/mcmc/gibbs.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,8 @@ function setparams_varinfo!!(
549549
# update its logprob. To do this, we have to call evaluate!! with the sampler, rather
550550
# than just a context, because ESS is peculiar in how it uses LikelihoodContext for
551551
# some variables and DefaultContext for others.
552-
return last(DynamicPPL.evaluate!!(model, params, SamplingContext(sampler)))
552+
# TODO(penelopeysm): Is this still needed?
553+
return last(DynamicPPL.evaluate!!(model, params, DynamicPPL.SamplingContext(sampler)))
553554
end
554555

555556
function setparams_varinfo!!(
@@ -559,7 +560,7 @@ function setparams_varinfo!!(
559560
params::AbstractVarInfo,
560561
)
561562
logdensity = DynamicPPL.LogDensityFunction(
562-
model, state.ldf.varinfo; adtype=sampler.alg.adtype
563+
model, DynamicPPL.getlogjoint, state.ldf.varinfo; adtype=sampler.alg.adtype
563564
)
564565
new_inner_state = setparams_varinfo!!(
565566
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params

src/mcmc/particle_mcmc.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function TracedModel(
1818
varinfo::AbstractVarInfo,
1919
rng::Random.AbstractRNG,
2020
)
21-
context = SamplingContext(rng, sampler, DefaultContext())
21+
context = DynamicPPL.SamplingContext(rng, sampler, DefaultContext())
2222
args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
2323
if kwargs !== nothing && !isempty(kwargs)
2424
error(
@@ -395,7 +395,7 @@ function AbstractMCMC.step(
395395
end
396396

397397
function DynamicPPL.use_threadsafe_eval(
398-
::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo
398+
::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo
399399
)
400400
return false
401401
end
@@ -457,7 +457,9 @@ end
457457
# end
458458

459459
function DynamicPPL.acclogp!!(
460-
context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
460+
context::DynamicPPL.SamplingContext{<:Sampler{<:Union{PG,SMC}}},
461+
varinfo::AbstractVarInfo,
462+
logp,
461463
)
462464
varinfo_trace = trace_local_varinfo_maybe(varinfo)
463465
return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)

src/mcmc/prior.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@ function AbstractMCMC.step(
1212
state=nothing;
1313
kwargs...,
1414
)
15+
# TODO(DPPL0.37/penelopeysm): replace with init!! instead
1516
vi = last(
1617
DynamicPPL.evaluate!!(
17-
model,
18-
VarInfo(),
19-
SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()),
18+
model, VarInfo(), DynamicPPL.SamplingContext(rng, DynamicPPL.SampleFromPrior())
2019
),
2120
)
22-
return vi, nothing
21+
# Need to manually construct the Transition here because we only
22+
# want to use the prior probability.
23+
xs = Turing.Inference.getparams(model, vi)
24+
lp = DynamicPPL.getlogprior(vi)
25+
return Transition(xs, lp, nothing)
2326
end
2427

2528
DynamicPPL.default_chain_type(sampler::Prior) = MCMCChains.Chains

0 commit comments

Comments
 (0)