Skip to content

Commit 5170e85

Browse files
committed
Ensure type stability
1 parent 914f1a0 commit 5170e85

File tree

1 file changed

+24
-10
lines changed

1 file changed

+24
-10
lines changed

src/abstractmcmc.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,17 +205,23 @@ function AbstractMCMC.step(
205205
return Transition(t.z, tstat), newstate
206206
end
207207

208-
struct SGHMCState{T<:AbstractVector{<:Real}}
208+
struct SGHMCState{
209+
TTrans<:Transition,
210+
TMetric<:AbstractMetric,
211+
TKernel<:AbstractMCMCKernel,
212+
TAdapt<:Adaptation.AbstractAdaptor,
213+
T<:AbstractVector{<:Real},
214+
}
209215
"Index of current iteration."
210-
i
216+
i::Int
211217
"Current [`Transition`](@ref)."
212-
transition
218+
transition::TTrans
213219
"Current [`AbstractMetric`](@ref), possibly adapted."
214-
metric
220+
metric::TMetric
215221
"Current [`AbstractMCMCKernel`](@ref)."
216-
κ
222+
κ::TKernel
217223
"Current [`AbstractAdaptor`](@ref)."
218-
adaptor
224+
adaptor::TAdapt
219225
velocity::T
220226
end
221227
getadaptor(state::SGHMCState) = state.adaptor
@@ -252,7 +258,7 @@ function AbstractMCMC.step(
252258
# Get an initial sample.
253259
h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)
254260

255-
state = SGHMCState(0, t, metric, κ, adaptor, initial_params, zero(initial_params))
261+
state = SGHMCState(0, t, metric, κ, adaptor, initial_params)
256262

257263
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
258264
end
@@ -265,6 +271,14 @@ function AbstractMCMC.step(
265271
n_adapts::Int=0,
266272
kwargs...,
267273
)
274+
if haskey(kwargs, :nadapts)
275+
throw(
276+
ArgumentError(
277+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
278+
),
279+
)
280+
end
281+
268282
i = state.i + 1
269283
t_old = state.transition
270284
adaptor = state.adaptor
@@ -289,14 +303,14 @@ function AbstractMCMC.step(
289303
α = spl.momentum_decay
290304
newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
291305

306+
# Make new transition.
307+
t = transition(rng, h, κ, t_old.z)
308+
292309
# Adapt h and spl.
293310
tstat = stat(t)
294311
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
295312
tstat = merge(tstat, (is_adapt=isadapted,))
296313

297-
# Make new transition.
298-
t = transition(rng, h, κ, t_old.z)
299-
300314
# Compute next sample and state.
301315
sample = Transition(t.z, tstat)
302316
newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)

0 commit comments

Comments
 (0)