@@ -205,17 +205,23 @@ function AbstractMCMC.step(
205
205
return Transition (t. z, tstat), newstate
206
206
end
207
207
208
- struct SGHMCState{T<: AbstractVector{<:Real} }
208
+ struct SGHMCState{
209
+ TTrans<: Transition ,
210
+ TMetric<: AbstractMetric ,
211
+ TKernel<: AbstractMCMCKernel ,
212
+ TAdapt<: Adaptation.AbstractAdaptor ,
213
+ T<: AbstractVector{<:Real} ,
214
+ }
209
215
" Index of current iteration."
210
- i
216
+ i:: Int
211
217
" Current [`Transition`](@ref)."
212
- transition
218
+ transition:: TTrans
213
219
" Current [`AbstractMetric`](@ref), possibly adapted."
214
- metric
220
+ metric:: TMetric
215
221
" Current [`AbstractMCMCKernel`](@ref)."
216
- κ
222
+ κ:: TKernel
217
223
" Current [`AbstractAdaptor`](@ref)."
218
- adaptor
224
+ adaptor:: TAdapt
219
225
velocity:: T
220
226
end
221
227
getadaptor (state:: SGHMCState ) = state. adaptor
@@ -252,7 +258,7 @@ function AbstractMCMC.step(
252
258
# Get an initial sample.
253
259
h, t = AdvancedHMC. sample_init (rng, hamiltonian, initial_params)
254
260
255
- state = SGHMCState (0 , t, metric, κ, adaptor, initial_params, zero (initial_params) )
261
+ state = SGHMCState (0 , t, metric, κ, adaptor, initial_params)
256
262
257
263
return AbstractMCMC. step (rng, model, spl, state; kwargs... )
258
264
end
@@ -265,6 +271,14 @@ function AbstractMCMC.step(
265
271
n_adapts:: Int = 0 ,
266
272
kwargs... ,
267
273
)
274
+ if haskey (kwargs, :nadapts )
275
+ throw (
276
+ ArgumentError (
277
+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
278
+ ),
279
+ )
280
+ end
281
+
268
282
i = state. i + 1
269
283
t_old = state. transition
270
284
adaptor = state. adaptor
@@ -289,14 +303,14 @@ function AbstractMCMC.step(
289
303
α = spl. momentum_decay
290
304
newv = (1 - α) .* v .+ η .* grad .+ sqrt (2 * η * α) .* randn (rng, eltype (v), length (v))
291
305
306
+ # Make new transition.
307
+ t = transition (rng, h, κ, t_old. z)
308
+
292
309
# Adapt h and spl.
293
310
tstat = stat (t)
294
311
h, κ, isadapted = adapt! (h, κ, adaptor, i, n_adapts, θ, tstat. acceptance_rate)
295
312
tstat = merge (tstat, (is_adapt= isadapted,))
296
313
297
- # Make new transition.
298
- t = transition (rng, h, κ, t_old. z)
299
-
300
314
# Compute next sample and state.
301
315
sample = Transition (t. z, tstat)
302
316
newstate = SGHMCState (i, t, h. metric, κ, adaptor, newv)
0 commit comments