Skip to content

Commit 57f9e2b

Browse files
committed
add comment
1 parent 8f801d4 commit 57f9e2b

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/abstractmcmc.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,17 @@ getmetric(state::HMCState) = state.metric
3131
getintegrator(state::HMCState) = state.κ.τ.integrator
3232

3333
function AbstractMCMC.getparams(state::HMCState)
34-
# TODO(sunxd): should we return a copy?
3534
return state.transition.z.θ
3635
end
3736

37+
# Using @set to update state.transition.z.θ can lead to inconsistencies:
38+
# - It retains cached log-joint and gradient computations that become invalid
39+
# - This can cause incorrect evaluations in subsequent steps (e.g. MH)
40+
#
41+
# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17
42+
# if in the future the interface provides access to the log density function
3843
function AbstractMCMC.setparams!!(state::HMCState, params)
39-
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
40-
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
41-
hamiltonian,
42-
params,
43-
state.transition.z.r;
44-
ℓκ = state.transition.z.ℓκ,
45-
)
44+
return @set state.transition.z.θ = θ
4645
end
4746

4847
"""
@@ -429,4 +428,4 @@ end
429428

430429
function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator)
431430
return spl.κ
432-
end
431+
end

0 commit comments

Comments
 (0)