Skip to content

Commit c7f3163

Browse files
committed
update implementation
1 parent b9dfa36 commit c7f3163

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

src/abstractmcmc.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,12 @@ function AbstractMCMC.getparams(state::HMCState)
3434
return state.transition.z.θ
3535
end
3636

37-
# Using @set to update state.transition.z.θ can lead to inconsistencies:
38-
# - It retains cached log-joint and gradient computations that become invalid
39-
# - This can cause incorrect evaluations in subsequent steps (e.g. MH)
40-
#
41-
# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17
42-
# if in the future the interface provides access to the log density function
43-
function AbstractMCMC.setparams!!(state::HMCState, params)
44-
return @set state.transition.z.θ = params
37+
function AbstractMCMC.setparams!!(model, state::HMCState, params)
38+
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
39+
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
40+
hamiltonian, params, state.transition.z.r;
41+
ℓκ=state.transition.z.ℓκ
42+
)
4543
end
4644

4745
"""

0 commit comments

Comments
 (0)