Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.6.2"
version = "0.6.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[compat]
AbstractMCMC = "5"
AbstractMCMC = "5.5"
ArgCheck = "1, 2"
CUDA = "3, 4, 5"
DocStringExtensions = "0.8, 0.9"
Expand Down
9 changes: 9 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ getadaptor(state::HMCState) = state.adaptor
getmetric(state::HMCState) = state.metric
getintegrator(state::HMCState) = state.κ.τ.integrator

function AbstractMCMC.getparams(state::HMCState)
# TODO(sunxd): should we return a copy?
return state.transition.z.θ
end

function AbstractMCMC.setparams!!(state::HMCState, θ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One dangerous aspect there is that state.transition caches the logjoint and gradient computations. So when you use the @set macro here, you're going to also keep the cached log-joint and gradient computation, which will then be out of sync with the parameters.

If you then naively pass this transition somewhere, say, into the next step call, AHMC.jl will use the incorrect logjoint eval in the MH step.

IMO the safe way is to use the explicit constructor of PhasePoint I believe without passing in the cached values. This should result in receomputation of this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized that model is not passed into setparams!!, I think for now, we only set the parameters, then when use the transition, logp should be recomputed. (we can also later introduce some like setlogp or compute_logp etc.) I'll add a comment in the code

Copy link
Member

@yebai yebai Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed with @torfjelde's concern. Can we add logdensitymodel to

  • setparams!!(state, logdensitymodel, params)
  • getparam(state, logdensitymodel).

where logdensitymodel follows the LogDensityProblem interface. This would allow us to recompute the model's log probability (i.e., recompute_logp) inside these setparam!! functions on demand, which Turing's new Gibbs sampler uses.

return @set state.transition.z.θ = θ
end

"""
$(TYPEDSIGNATURES)

Expand Down
12 changes: 12 additions & 0 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ using Statistics: mean
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo),
)

@testset "getparams and setparams!!" begin
t, s = AbstractMCMC.step(rng, model, nuts;)

θ = AbstractMCMC.getparams(s)
@test θ == t.z.θ
@test AbstractMCMC.setparams!!(s, θ) == s

new_θ = randn(rng, 2)
new_state = AbstractMCMC.setparams!!(s, new_θ)
@test AbstractMCMC.getparams(new_state) == new_θ
end

samples_nuts = AbstractMCMC.sample(
rng,
model,
Expand Down
Loading