Skip to content

Commit 6bdf29d

Browse files
committed
fix externalsampler in gibbs
1 parent 694aa60 commit 6bdf29d

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

src/mcmc/external_sampler.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,20 @@ function externalsampler(
8383
return ExternalSampler(sampler, adtype, Val(unconstrained))
8484
end
8585

86-
struct TuringState{S,M,V}
86+
# TODO(penelopeysm): Can't we clean this up somehow?
87+
struct TuringState{S,V1,M,V}
8788
state::S
89+
# Note that this varinfo must have the correct parameters set; but logp
90+
# does not matter as it will be re-evaluated
91+
varinfo::V1
8892
# Note that in general the VarInfo inside this LogDensityFunction will have
8993
# junk parameters and logp. It only exists to provide structure
9094
ldf::DynamicPPL.LogDensityFunction{M,V}
9195
end
9296

93-
get_varinfo(state::TuringState) = state.ldf.varinfo
97+
# get_varinfo should return something from which the correct parameters can be
98+
# obtained, hence we use state.varinfo rather than state.ldf.varinfo
99+
get_varinfo(state::TuringState) = state.varinfo
94100
get_varinfo(state::AbstractVarInfo) = state
95101

96102
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
@@ -148,8 +154,10 @@ function AbstractMCMC.step(
148154
end
149155

150156
new_parameters = getparams(f.model, state_inner)
151-
vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
152-
return (Transition(f.model, vi, transition_inner), TuringState(state_inner, f))
157+
new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
158+
return (
159+
Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
160+
)
153161
end
154162

155163
function AbstractMCMC.step(
@@ -168,6 +176,8 @@ function AbstractMCMC.step(
168176
)
169177

170178
new_parameters = getparams(f.model, state_inner)
171-
vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
172-
return (Transition(f.model, vi, transition_inner), TuringState(state_inner, f))
179+
new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
180+
return (
181+
Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
182+
)
173183
end

src/mcmc/gibbs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ function setparams_varinfo!!(
573573
new_inner_state = setparams_varinfo!!(
574574
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
575575
)
576-
return TuringState(new_inner_state, logdensity)
576+
return TuringState(new_inner_state, params, logdensity)
577577
end
578578

579579
function setparams_varinfo!!(

0 commit comments

Comments
 (0)