Skip to content

Commit 4764120

Browse files
Tor FjeldeTor Fjelde
authored andcommitted
formatting
1 parent 9247281 commit 4764120

File tree

1 file changed

+65
-22
lines changed

1 file changed

+65
-22
lines changed

src/RobustAdaptiveMetropolis.jl

Lines changed: 65 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,18 @@ true
7070
# References
7171
[^VIH12]: Vihola (2012) Robust adaptive Metropolis algorithm with coerced acceptance rate, Statistics and computing.
7272
"""
73-
Base.@kwdef struct RobustAdaptiveMetropolis{T,A<:Union{Nothing,AbstractMatrix{T}}} <: AdvancedMH.MHSampler
73+
Base.@kwdef struct RobustAdaptiveMetropolis{T,A<:Union{Nothing,AbstractMatrix{T}}} <:
74+
AdvancedMH.MHSampler
7475
"target acceptance rate. Default: 0.234."
75-
α::T=0.234
76+
α::T = 0.234
7677
"negative exponent of the adaptation decay rate. Default: `0.6`."
77-
γ::T=0.6
78+
γ::T = 0.6
7879
"initial lower-triangular Cholesky factor. Default: `nothing`."
79-
S::A=nothing
80+
S::A = nothing
8081
"lower bound on eigenvalues of the adapted Cholesky factor. Default: `0.0`."
81-
eigenvalue_lower_bound::T=0.0
82+
eigenvalue_lower_bound::T = 0.0
8283
"upper bound on eigenvalues of the adapted Cholesky factor. Default: `Inf`."
83-
eigenvalue_upper_bound::T=Inf
84+
eigenvalue_upper_bound::T = Inf
8485
end
8586

8687
"""
@@ -111,13 +112,22 @@ struct RobustAdaptiveMetropolisState{T1,L,A,T2,T3}
111112
end
112113

113114
AbstractMCMC.getparams(state::RobustAdaptiveMetropolisState) = state.x
114-
AbstractMCMC.setparams!!(state::RobustAdaptiveMetropolisState, x) = RobustAdaptiveMetropolisState(x, state.logprob, state.S, state.logα, state.η, state.iteration, state.isaccept)
115+
AbstractMCMC.setparams!!(state::RobustAdaptiveMetropolisState, x) =
116+
RobustAdaptiveMetropolisState(
117+
x,
118+
state.logprob,
119+
state.S,
120+
state.logα,
121+
state.η,
122+
state.iteration,
123+
state.isaccept,
124+
)
115125

116126
function ram_step_inner(
117127
rng::Random.AbstractRNG,
118128
model::AbstractMCMC.LogDensityModel,
119129
sampler::RobustAdaptiveMetropolis,
120-
state::RobustAdaptiveMetropolisState
130+
state::RobustAdaptiveMetropolisState,
121131
)
122132
# This is the initial state.
123133
f = model.logdensity
@@ -137,7 +147,12 @@ function ram_step_inner(
137147
return x_new, lp_new, U, logα, isaccept
138148
end
139149

140-
function ram_adapt(sampler::RobustAdaptiveMetropolis, state::RobustAdaptiveMetropolisState, logα::Real, U::AbstractVector)
150+
function ram_adapt(
151+
sampler::RobustAdaptiveMetropolis,
152+
state::RobustAdaptiveMetropolisState,
153+
logα::Real,
154+
U::AbstractVector,
155+
)
141156
Δα = exp(logα) - sampler.α
142157
S = state.S
143158
# TODO: Make this configurable by defining a more general path.
@@ -158,18 +173,25 @@ function AbstractMCMC.step(
158173
rng::Random.AbstractRNG,
159174
model::AbstractMCMC.LogDensityModel,
160175
sampler::RobustAdaptiveMetropolis;
161-
initial_params=nothing,
162-
kwargs...
176+
initial_params = nothing,
177+
kwargs...,
163178
)
164179
# This is the initial state.
165180
f = model.logdensity
166181
d = LogDensityProblems.dimension(f)
167182

168183
# Initial parameter state.
169-
T = initial_params === nothing ? eltype(sampler.γ) : Base.promote_type(eltype(sampler.γ), eltype(initial_params))
170-
x = initial_params === nothing ? rand(rng, T, d) : convert(AbstractVector{T}, initial_params)
184+
T =
185+
initial_params === nothing ? eltype(sampler.γ) :
186+
Base.promote_type(eltype(sampler.γ), eltype(initial_params))
187+
x =
188+
initial_params === nothing ? rand(rng, T, d) :
189+
convert(AbstractVector{T}, initial_params)
171190
# Initialize the Cholesky factor of the covariance matrix.
172-
S = LowerTriangular(sampler.S === nothing ? diagm(0 => ones(T, d)) : convert(AbstractMatrix{T}, sampler.S))
191+
S = LowerTriangular(
192+
sampler.S === nothing ? diagm(0 => ones(T, d)) :
193+
convert(AbstractMatrix{T}, sampler.S),
194+
)
173195

174196
# Construct the initial state.
175197
lp = LogDensityProblems.logdensity(f, x)
@@ -183,13 +205,22 @@ function AbstractMCMC.step(
183205
model::AbstractMCMC.LogDensityModel,
184206
sampler::RobustAdaptiveMetropolis,
185207
state::RobustAdaptiveMetropolisState;
186-
kwargs...
208+
kwargs...,
187209
)
188210
# Take the inner step.
189211
x_new, lp_new, U, logα, isaccept = ram_step_inner(rng, model, sampler, state)
190212
# Accept / reject the proposal.
191-
state_new = RobustAdaptiveMetropolisState(isaccept ? x_new : state.x, isaccept ? lp_new : state.logprob, state.S, logα, state.η, state.iteration + 1, isaccept)
192-
return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept), state_new
213+
state_new = RobustAdaptiveMetropolisState(
214+
isaccept ? x_new : state.x,
215+
isaccept ? lp_new : state.logprob,
216+
state.S,
217+
logα,
218+
state.η,
219+
state.iteration + 1,
220+
isaccept,
221+
)
222+
return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept),
223+
state_new
193224
end
194225

195226
function valid_eigenvalues(S, lower_bound, upper_bound)
@@ -205,20 +236,32 @@ function AbstractMCMC.step_warmup(
205236
model::AbstractMCMC.LogDensityModel,
206237
sampler::RobustAdaptiveMetropolis,
207238
state::RobustAdaptiveMetropolisState;
208-
kwargs...
239+
kwargs...,
209240
)
210241
# Take the inner step.
211242
x_new, lp_new, U, logα, isaccept = ram_step_inner(rng, model, sampler, state)
212243
# Adapt the proposal.
213244
S_new, η = ram_adapt(sampler, state, logα, U)
214245
# Check that `S_new` has eigenvalues in the desired range.
215-
if !valid_eigenvalues(S_new, sampler.eigenvalue_lower_bound, sampler.eigenvalue_upper_bound)
246+
if !valid_eigenvalues(
247+
S_new,
248+
sampler.eigenvalue_lower_bound,
249+
sampler.eigenvalue_upper_bound,
250+
)
216251
# In this case, we just keep the old `S` (p. 13 in Vihola, 2012).
217252
S_new = state.S
218253
end
219254

220255
# Update state.
221-
state_new = RobustAdaptiveMetropolisState(isaccept ? x_new : state.x, isaccept ? lp_new : state.logprob, S_new, logα, η, state.iteration + 1, isaccept)
222-
return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept), state_new
256+
state_new = RobustAdaptiveMetropolisState(
257+
isaccept ? x_new : state.x,
258+
isaccept ? lp_new : state.logprob,
259+
S_new,
260+
logα,
261+
η,
262+
state.iteration + 1,
263+
isaccept,
264+
)
265+
return AdvancedMH.Transition(state_new.x, state_new.logprob, state_new.isaccept),
266+
state_new
223267
end
224-

0 commit comments

Comments
 (0)