Skip to content

Commit 145a52b

Browse files
committed
Implement AbstractMCMC.getstats
1 parent 6ac07d5 commit 145a52b

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

src/AdvancedMH.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ end
146146
function AbstractMCMC.getparams(t::Transition)
147147
return t.params
148148
end
149+
function AbstractMCMC.getstats(t::Transition)
150+
return (accepted=t.accepted,)
151+
end
149152

150153
# TODO (sunxd): remove `DensityModel` in favor of `AbstractMCMC.LogDensityModel`
151154
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::Transition, params)

src/MALA.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp
2323
function AbstractMCMC.getparams(t::GradientTransition)
2424
return t.params
2525
end
26+
function AbstractMCMC.getstats(t::GradientTransition)
27+
return (accepted=t.accepted,)
28+
end
2629

2730
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::GradientTransition, params)
2831
lp, gradient = logdensity_and_gradient(model, params)

src/RobustAdaptiveMetropolis.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,18 @@ end
116116
AbstractMCMC.getparams(state::RobustAdaptiveMetropolisState) = state.x
117117
function AbstractMCMC.setparams!!(state::RobustAdaptiveMetropolisState, x)
118118
return RobustAdaptiveMetropolisState(
119-
x, state.logprob, state.S, state.logα, state.η, state.iteration, state.isaccept
119+
x,
120+
state.logprob,
121+
state.S,
122+
state.logα,
123+
state.η,
124+
state.iteration,
125+
state.isaccept,
120126
)
121127
end
128+
function AbstractMCMC.getstats(state::RobustAdaptiveMetropolisState)
129+
return (logα = state.logα, η = state.η, accepted = state.isaccept)
130+
end
122131

123132
function ram_step_inner(
124133
rng::Random.AbstractRNG,

0 commit comments

Comments
 (0)