Skip to content

Commit 20f9e97

Browse files
committed
More test fixes
1 parent c09c2a5 commit 20f9e97

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

src/mcmc/emcee.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,15 @@ end
3535
_get_n_walkers(e::Emcee) = e.ensemble.n_walkers
3636
_get_n_walkers(spl::Sampler{<:Emcee}) = _get_n_walkers(spl.alg)
3737

38+
# Because Emcee expects n_walkers initialisations, we need to override this
39+
DynamicPPL.init_strategy(spl::Sampler{<:Emcee}) = fill(InitFromPrior(), _get_n_walkers(spl))
40+
3841
function AbstractMCMC.step(
3942
rng::Random.AbstractRNG,
4043
model::Model,
4144
spl::Sampler{<:Emcee};
4245
resume_from=nothing,
43-
initial_params=fill(DynamicPPL.init_strategy(spl), _get_n_walkers(spl)),
46+
initial_params,
4447
kwargs...,
4548
)
4649
if resume_from !== nothing

src/mcmc/is.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
5151
end
5252
DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf()
5353

54-
function DynamicPPL.tilde_assume!!(ctx::ISContext, dist::Distribution, vn::VarName, vi)
54+
function DynamicPPL.tilde_assume!!(
55+
ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
56+
)
5557
if haskey(vi, vn)
5658
r = vi[vn]
5759
else
@@ -61,3 +63,8 @@ function DynamicPPL.tilde_assume!!(ctx::ISContext, dist::Distribution, vn::VarNa
6163
vi = DynamicPPL.accumulate_assume!!(vi, r, 0.0, vn, dist)
6264
return r, vi
6365
end
66+
function DynamicPPL.tilde_observe!!(
67+
::ISContext, right::Distribution, left, vn::Union{VarName,Nothing}, vi::AbstractVarInfo
68+
)
69+
return DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi)
70+
end

src/mcmc/particle_mcmc.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ function AbstractMCMC.sample(
144144
N::Integer;
145145
chain_type=DynamicPPL.default_chain_type(sampler),
146146
resume_from=nothing,
147+
initial_params=DynamicPPL.init_strategy(sampler),
147148
initial_state=DynamicPPL.loadstate(resume_from),
148149
progress=PROGRESS[],
149150
kwargs...,
@@ -155,6 +156,7 @@ function AbstractMCMC.sample(
155156
sampler,
156157
N;
157158
chain_type=chain_type,
159+
initial_params=initial_params,
158160
progress=progress,
159161
nparticles=N,
160162
kwargs...,
@@ -166,6 +168,7 @@ function AbstractMCMC.sample(
166168
sampler,
167169
N;
168170
chain_type,
171+
initial_params=initial_params,
169172
initial_state,
170173
progress=progress,
171174
nparticles=N,

0 commit comments

Comments
 (0)