@@ -205,6 +205,105 @@ function AbstractMCMC.step(
205
205
return Transition (t. z, tstat), newstate
206
206
end
207
207
208
+ struct SGHMCState{T<: AbstractVector{<:Real} }
209
+ " Index of current iteration."
210
+ i
211
+ " Current [`Transition`](@ref)."
212
+ transition
213
+ " Current [`AbstractMetric`](@ref), possibly adapted."
214
+ metric
215
+ " Current [`AbstractMCMCKernel`](@ref)."
216
+ κ
217
+ " Current [`AbstractAdaptor`](@ref)."
218
+ adaptor
219
+ velocity:: T
220
+ end
221
+ getadaptor (state:: SGHMCState ) = state. adaptor
222
+ getmetric (state:: SGHMCState ) = state. metric
223
+ getintegrator (state:: SGHMCState ) = state. κ. τ. integrator
224
+
225
+ function AbstractMCMC. step (
226
+ rng:: Random.AbstractRNG ,
227
+ model:: AbstractMCMC.LogDensityModel ,
228
+ spl:: SGHMC ;
229
+ initial_params= nothing ,
230
+ kwargs... ,
231
+ )
232
+ # Unpack model
233
+ logdensity = model. logdensity
234
+
235
+ # Define metric
236
+ metric = make_metric (spl, logdensity)
237
+
238
+ # Construct the hamiltonian using the initial metric
239
+ hamiltonian = Hamiltonian (metric, model)
240
+
241
+ # Compute initial sample and state.
242
+ initial_params = make_initial_params (rng, spl, logdensity, initial_params)
243
+ ϵ = make_step_size (rng, spl, hamiltonian, initial_params)
244
+ integrator = make_integrator (spl, ϵ)
245
+
246
+ # Make kernel
247
+ κ = make_kernel (spl, integrator)
248
+
249
+ # Make adaptor
250
+ adaptor = make_adaptor (spl, metric, integrator)
251
+
252
+ # Get an initial sample.
253
+ h, t = AdvancedHMC. sample_init (rng, hamiltonian, initial_params)
254
+
255
+ state = SGHMCState (0 , t, metric, κ, adaptor, initial_params, zero (initial_params))
256
+
257
+ return AbstractMCMC. step (rng, model, spl, state; kwargs... )
258
+ end
259
+
260
+ function AbstractMCMC. step (
261
+ rng:: AbstractRNG ,
262
+ model:: AbstractMCMC.LogDensityModel ,
263
+ spl:: SGHMC ,
264
+ state:: SGHMCState ;
265
+ n_adapts:: Int = 0 ,
266
+ kwargs... ,
267
+ )
268
+ i = state. i + 1
269
+ t_old = state. transition
270
+ adaptor = state. adaptor
271
+ κ = state. κ
272
+ metric = state. metric
273
+
274
+ # Reconstruct hamiltonian.
275
+ h = Hamiltonian (metric, model)
276
+
277
+ # Compute gradient of log density.
278
+ logdensity_and_gradient = Base. Fix1 (
279
+ LogDensityProblems. logdensity_and_gradient, model. logdensity
280
+ )
281
+ θ = t_old. z. θ
282
+ grad = last (logdensity_and_gradient (θ))
283
+
284
+ # Update latent variables and velocity according to
285
+ # equation (15) of Chen et al. (2014)
286
+ v = state. velocity
287
+ θ .+ = v
288
+ η = spl. learning_rate
289
+ α = spl. momentum_decay
290
+ newv = (1 - α) .* v .+ η .* grad .+ sqrt (2 * η * α) .* randn (rng, eltype (v), length (v))
291
+
292
+ # Adapt h and spl.
293
+ tstat = stat (t)
294
+ h, κ, isadapted = adapt! (h, κ, adaptor, i, n_adapts, θ, tstat. acceptance_rate)
295
+ tstat = merge (tstat, (is_adapt= isadapted,))
296
+
297
+ # Make new transition.
298
+ t = transition (rng, h, κ, t_old. z)
299
+
300
+ # Compute next sample and state.
301
+ sample = Transition (t. z, tstat)
302
+ newstate = SGHMCState (i, t, h. metric, κ, adaptor, newv)
303
+
304
+ return sample, newstate
305
+ end
306
+
208
307
# ###############
209
308
# ## Callback ###
210
309
# ###############
@@ -392,6 +491,10 @@ function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractInte
392
491
return NoAdaptation ()
393
492
end
394
493
494
+ function make_adaptor (spl:: SGHMC , metric:: AbstractMetric , integrator:: AbstractIntegrator )
495
+ return NoAdaptation ()
496
+ end
497
+
395
498
function make_adaptor (
396
499
spl:: HMCSampler , metric:: AbstractMetric , integrator:: AbstractIntegrator
397
500
)
417
520
function make_kernel (spl:: HMCSampler , integrator:: AbstractIntegrator )
418
521
return spl. κ
419
522
end
523
+
524
+ function make_kernel (spl:: SGHMC , integrator:: AbstractIntegrator )
525
+ return HMCKernel (Trajectory {EndPointTS} (integrator, FixedNSteps (spl. n_leapfrog)))
526
+ end
0 commit comments