Skip to content

Commit c09c2a5

Browse files
committed
Fix more tests
1 parent ed197f9 commit c09c2a5

File tree

4 files changed

+19
-21
lines changed

4 files changed

+19
-21
lines changed

src/mcmc/is.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state)
4646
return logsumexp(map(x -> x.loglikelihood, samples)) - log(length(samples))
4747
end
4848

49-
struct ISContext{R<:AbstractRNG}
49+
struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
5050
rng::R
5151
end
52+
DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf()
5253

5354
function DynamicPPL.tilde_assume!!(ctx::ISContext, dist::Distribution, vn::VarName, vi)
5455
if haskey(vi, vn)

src/mcmc/mh.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,8 @@ function DynamicPPL.tilde_assume!!(
415415
end
416416
return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi)
417417
end
418+
function DynamicPPL.tilde_observe!!(
419+
::MHContext, right::Distribution, left, vn::Union{VarName,Nothing}, vi::AbstractVarInfo
420+
)
421+
return DynamicPPL.tilde_observe!!(DefaultContext(), right, left, vn, vi)
422+
end

test/ad.jl

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -154,31 +154,23 @@ end
154154
# context, and then call check_adtype on the result before returning the results from the
155155
# child context.
156156

157-
function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi)
158-
value, vi = DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi)
159-
check_adtype(context, vi)
160-
return value, vi
161-
end
162-
163-
function DynamicPPL.tilde_assume(
164-
rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi
157+
function DynamicPPL.tilde_assume!!(
158+
context::ADTypeCheckContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
165159
)
166-
value, vi = DynamicPPL.tilde_assume(
167-
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
168-
)
160+
value, vi = DynamicPPL.tilde_assume!!(DynamicPPL.childcontext(context), right, vn, vi)
169161
check_adtype(context, vi)
170162
return value, vi
171163
end
172164

173-
function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, right, left, vi)
174-
left, vi = DynamicPPL.tilde_observe!!(DynamicPPL.childcontext(context), right, left, vi)
175-
check_adtype(context, vi)
176-
return left, vi
177-
end
178-
179-
function DynamicPPL.tilde_observe!!(context::ADTypeCheckContext, sampler, right, left, vi)
165+
function DynamicPPL.tilde_observe!!(
166+
context::ADTypeCheckContext,
167+
right::Distribution,
168+
left,
169+
vn::Union{VarName,Nothing},
170+
vi::AbstractVarInfo,
171+
)
180172
left, vi = DynamicPPL.tilde_observe!!(
181-
DynamicPPL.childcontext(context), sampler, right, left, vi
173+
DynamicPPL.childcontext(context), right, left, vn, vi
182174
)
183175
check_adtype(context, vi)
184176
return left, vi

test/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ end
695695
# Determine initial parameters to make comparison as fair as possible.
696696
# posterior_mean returns a NamedTuple so we can plug it in directly.
697697
posterior_mean = DynamicPPL.TestUtils.posterior_mean(model)
698-
initial_params = fill(InitFromParams(initial_params), num_chains)
698+
initial_params = fill(InitFromParams(posterior_mean), num_chains)
699699

700700
# Sampler to use for Gibbs components.
701701
hmc = HMC(0.1, 32)

0 commit comments

Comments
 (0)