Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.6.2"
version = "0.6.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[compat]
AbstractMCMC = "5"
AbstractMCMC = "5.5"
ArgCheck = "1, 2"
CUDA = "3, 4, 5"
DocStringExtensions = "0.8, 0.9"
Expand Down
2 changes: 1 addition & 1 deletion research/tests/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ include("../src/riemannian_hmc.jl")
include("relativistic_hmc.jl")
include("riemannian_hmc.jl")

@main function runtests(patterns...; dry::Bool = false)
Comonicon.@main function runtests(patterns...; dry::Bool = false)
retest(patterns...; dry = dry, verbose = Inf)
end
16 changes: 15 additions & 1 deletion src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ getadaptor(state::HMCState) = state.adaptor
getmetric(state::HMCState) = state.metric
getintegrator(state::HMCState) = state.κ.τ.integrator

function AbstractMCMC.getparams(state::HMCState)
return state.transition.z.θ
end

# Using @set to update state.transition.z.θ can lead to inconsistencies:
# - It retains cached log-joint and gradient computations that become invalid
# - This can cause incorrect evaluations in subsequent steps (e.g. MH)
#
# TODO: adopt https://github.com/TuringLang/MCMCTempering.jl/blob/deb96684496f3fbd011b9f70f28c49a161def23f/ext/MCMCTemperingAdvancedHMCExt.jl#L10-L17
# if in the future the interface provides access to the log density function
function AbstractMCMC.setparams!!(state::HMCState, params)
return @set state.transition.z.θ = θ
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -414,4 +428,4 @@ end

function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator)
return spl.κ
end
end
12 changes: 12 additions & 0 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ using Statistics: mean
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo),
)

@testset "getparams and setparams!!" begin
t, s = AbstractMCMC.step(rng, model, nuts;)

θ = AbstractMCMC.getparams(s)
@test θ == t.z.θ
@test AbstractMCMC.setparams!!(s, θ) == s

new_θ = randn(rng, 2)
new_state = AbstractMCMC.setparams!!(s, new_θ)
@test AbstractMCMC.getparams(new_state) == new_θ
end

samples_nuts = AbstractMCMC.sample(
rng,
model,
Expand Down
Loading