Skip to content

Commit 3636a14

Browse files
committed
Rename varinfo -> get_varinfo
1 parent d1f13a9 commit 3636a14

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

src/mcmc/external_sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ struct TuringState{S,M,V}
9090
ldf::DynamicPPL.LogDensityFunction{M,V}
9191
end
9292

93-
varinfo(state::TuringState) = state.ldf.varinfo
94-
varinfo(state::AbstractVarInfo) = state
93+
get_varinfo(state::TuringState) = state.ldf.varinfo
94+
get_varinfo(state::AbstractVarInfo) = state
9595

9696
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
9797
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)

src/mcmc/gibbs.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
342342
states::S
343343
end
344344

345-
varinfo(state::GibbsState) = state.vi
345+
get_varinfo(state::GibbsState) = state.vi
346346

347347
"""
348348
Initialise a VarInfo for the Gibbs sampler.
@@ -464,7 +464,7 @@ function gibbs_initialstep_recursive(
464464
initial_params=initial_params_local,
465465
kwargs...,
466466
)
467-
new_vi_local = varinfo(new_state)
467+
new_vi_local = get_varinfo(new_state)
468468
# Merge in any new variables that were introduced during the step, but that
469469
# were not in the domain of the current sampler.
470470
vi = merge(vi, get_global_varinfo(context))
@@ -492,7 +492,7 @@ function AbstractMCMC.step(
492492
state::GibbsState;
493493
kwargs...,
494494
)
495-
vi = varinfo(state)
495+
vi = get_varinfo(state)
496496
alg = spl.alg
497497
varnames = alg.varnames
498498
samplers = alg.samplers
@@ -512,7 +512,7 @@ function AbstractMCMC.step_warmup(
512512
state::GibbsState;
513513
kwargs...,
514514
)
515-
vi = varinfo(state)
515+
vi = get_varinfo(state)
516516
alg = spl.alg
517517
varnames = alg.varnames
518518
samplers = alg.samplers
@@ -607,7 +607,7 @@ state for this sampler. This is relevant when multilple samplers are sampling th
607607
variables, and one might need it to be linked while the other doesn't.
608608
"""
609609
function match_linking!!(varinfo_local, prev_state_local, model)
610-
prev_varinfo_local = varinfo(prev_state_local)
610+
prev_varinfo_local = get_varinfo(prev_state_local)
611611
was_linked = DynamicPPL.istrans(prev_varinfo_local)
612612
is_linked = DynamicPPL.istrans(varinfo_local)
613613
if was_linked && !is_linked
@@ -689,7 +689,7 @@ function gibbs_step_recursive(
689689
# Take a step with the local sampler.
690690
new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...))
691691

692-
new_vi_local = varinfo(new_state)
692+
new_vi_local = get_varinfo(new_state)
693693
# Merge the latest values for all the variables in the current sampler.
694694
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
695695
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))

test/mcmc/gibbs.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -822,11 +822,11 @@ end
822822
]
823823
@testset "$(sampler_inner)" for sampler_inner in samplers_inner
824824
sampler = Gibbs(@varname(m1) => sampler_inner, @varname(m2) => sampler_inner)
825-
# Random.seed!(42)
826-
# chain = sample(
827-
# model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
828-
# )
829-
# check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
825+
Random.seed!(42)
826+
chain = sample(
827+
model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
828+
)
829+
check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
830830
check_logp_correct(sampler_inner)
831831
end
832832
end

0 commit comments

Comments
 (0)