Commit a279425

Add Stochastic Gradient HMC

1 parent ce8fa6f commit a279425

File tree

4 files changed: +177 −1 lines changed

src/AdvancedHMC.jl

Lines changed: 1 addition & 1 deletion

@@ -125,7 +125,7 @@ include("sampler.jl")
 export sample
 
 include("constructors.jl")
-export HMCSampler, HMC, NUTS, HMCDA
+export HMCSampler, HMC, NUTS, HMCDA, SGHMC
 
 include("abstractmcmc.jl")
src/abstractmcmc.jl

Lines changed: 107 additions & 0 deletions

@@ -205,6 +205,105 @@ function AbstractMCMC.step(
     return Transition(t.z, tstat), newstate
 end
 
+struct SGHMCState{T<:AbstractVector{<:Real}}
+    "Index of current iteration."
+    i::Int
+    "Current [`Transition`](@ref)."
+    transition
+    "Current [`AbstractMetric`](@ref), possibly adapted."
+    metric
+    "Current [`AbstractMCMCKernel`](@ref)."
+    κ
+    "Current [`AbstractAdaptor`](@ref)."
+    adaptor
+    "Current velocity."
+    velocity::T
+end
+
+getadaptor(state::SGHMCState) = state.adaptor
+getmetric(state::SGHMCState) = state.metric
+getintegrator(state::SGHMCState) = state.κ.τ.integrator
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    spl::SGHMC;
+    initial_params=nothing,
+    kwargs...,
+)
+    # Unpack model.
+    logdensity = model.logdensity
+
+    # Define metric.
+    metric = make_metric(spl, logdensity)
+
+    # Construct the hamiltonian using the initial metric.
+    hamiltonian = Hamiltonian(metric, model)
+
+    # Compute initial sample and state.
+    initial_params = make_initial_params(rng, spl, logdensity, initial_params)
+    ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
+    integrator = make_integrator(spl, ϵ)
+
+    # Make kernel.
+    κ = make_kernel(spl, integrator)
+
+    # Make adaptor.
+    adaptor = make_adaptor(spl, metric, integrator)
+
+    # Get an initial sample.
+    h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)
+
+    # Start from zero velocity.
+    state = SGHMCState(0, t, metric, κ, adaptor, zero(initial_params))
+
+    return AbstractMCMC.step(rng, model, spl, state; kwargs...)
+end
+
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::AbstractMCMC.LogDensityModel,
+    spl::SGHMC,
+    state::SGHMCState;
+    n_adapts::Int=0,
+    kwargs...,
+)
+    i = state.i + 1
+    t_old = state.transition
+    adaptor = state.adaptor
+    κ = state.κ
+    metric = state.metric
+
+    # Reconstruct hamiltonian.
+    h = Hamiltonian(metric, model)
+
+    # Compute gradient of log density.
+    logdensity_and_gradient = Base.Fix1(
+        LogDensityProblems.logdensity_and_gradient, model.logdensity
+    )
+    θ = t_old.z.θ
+    grad = last(logdensity_and_gradient(θ))
+
+    # Update latent variables and velocity according to
+    # equation (15) of Chen et al. (2014). The position update mutates
+    # `θ` (and hence `t_old.z`) in place.
+    v = state.velocity
+    θ .+= v
+    η = spl.learning_rate
+    α = spl.momentum_decay
+    newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
+
+    # Make new transition from the updated position.
+    t = transition(rng, h, κ, t_old.z)
+
+    # Adapt h and spl.
+    tstat = stat(t)
+    h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
+    tstat = merge(tstat, (is_adapt=isadapted,))
+
+    # Compute next sample and state.
+    sample = Transition(t.z, tstat)
+    newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)
+
+    return sample, newstate
+end
+
 ################
 ### Callback ###
 ################
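For reference, the velocity update in the SGHMC step above transcribes equation (15) of Chen, Fox, and Guestrin (2014) with the gradient-noise estimate β̂ taken to be zero; with learning rate η and momentum decay α, the update reads (a sketch of the correspondence, not part of the commit):

    \begin{aligned}
    \Delta\theta &= v, \\
    \Delta v     &= \eta \, \nabla \log p(\theta) \;-\; \alpha v \;+\; \mathcal{N}\bigl(0,\; 2\alpha\eta\bigr).
    \end{aligned}

The `sqrt(2 * η * α)` factor in the code is the standard deviation of that Gaussian noise term, drawn independently per coordinate via `randn`.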
@@ -392,6 +491,10 @@ function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractInte
     return NoAdaptation()
 end
 
+function make_adaptor(spl::SGHMC, metric::AbstractMetric, integrator::AbstractIntegrator)
+    return NoAdaptation()
+end
+
 function make_adaptor(
     spl::HMCSampler, metric::AbstractMetric, integrator::AbstractIntegrator
 )
@@ -417,3 +520,7 @@ end
 function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator)
     return spl.κ
 end
+
+function make_kernel(spl::SGHMC, integrator::AbstractIntegrator)
+    return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(spl.n_leapfrog)))
+end

src/constructors.jl

Lines changed: 45 additions & 0 deletions

@@ -163,3 +163,48 @@ function HMCDA(δ, λ; integrator=:leapfrog, metric=:diagonal)
 end
 
 sampler_eltype(::HMCDA{T}) where {T} = T
+
+#############
+### SGHMC ###
+#############
+
+"""
+    SGHMC(learning_rate::Real, momentum_decay::Real, n_leapfrog::Int; integrator = :leapfrog, metric = :diagonal)
+
+Stochastic Gradient Hamiltonian Monte Carlo sampler.
+
+# Fields
+
+$(FIELDS)
+
+# Notes
+
+For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1402.4102)):
+
+- Chen, Tianqi, Emily Fox, and Carlos Guestrin. "Stochastic Gradient Hamiltonian Monte Carlo." International Conference on Machine Learning. PMLR, 2014.
+"""
+struct SGHMC{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
+       AbstractHMCSampler
+    "Learning rate for the gradient descent."
+    learning_rate::T
+    "Momentum decay rate."
+    momentum_decay::T
+    "Number of leapfrog steps."
+    n_leapfrog::Int
+    "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)."
+    integrator::I
+    "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
+    metric::M
+end
+
+function SGHMC(
+    learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal
+)
+    T = determine_sampler_eltype(
+        learning_rate, momentum_decay, n_leapfrog, integrator, metric
+    )
+    return SGHMC(T(learning_rate), T(momentum_decay), n_leapfrog, integrator, metric)
+end
+
+sampler_eltype(::SGHMC{T}) where {T} = T
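A minimal usage sketch of the new constructor (illustrative, not part of the commit; the `StdNormal` target and all values shown are hypothetical stand-ins for any first-order LogDensityProblems-compatible model):

using AdvancedHMC, AbstractMCMC, LogDensityProblems, Random

# Hypothetical target: a 2-d standard normal implementing the
# first-order LogDensityProblems interface that the SGHMC step uses.
struct StdNormal end
LogDensityProblems.logdensity(::StdNormal, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(::StdNormal) = 2
LogDensityProblems.capabilities(::Type{StdNormal}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity_and_gradient(::StdNormal, θ) = (-sum(abs2, θ) / 2, -θ)

# Learning rate, momentum decay, and leapfrog count mirror the new test below.
sghmc = SGHMC(0.01, 0.1, 100)
model = AbstractMCMC.LogDensityModel(StdNormal())
samples = AbstractMCMC.sample(
    Random.default_rng(), model, sghmc, 1_000;
    initial_params=zeros(2), progress=false,
)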

test/abstractmcmc.jl

Lines changed: 24 additions & 0 deletions

@@ -10,6 +10,7 @@ using Statistics: mean
 nuts = NUTS(0.8)
 hmc = HMC(100; integrator=Leapfrog(0.05))
 hmcda = HMCDA(0.8, 0.1)
+sghmc = SGHMC(0.01, 0.1, 100)
 
 integrator = Leapfrog(1e-3)
 κ = AdvancedHMC.make_kernel(nuts, integrator)
@@ -111,6 +112,29 @@ using Statistics: mean
 
     @test m_est_hmc ≈ [49 / 24, 7 / 6] atol = RNDATOL
 
+    samples_sghmc = AbstractMCMC.sample(
+        rng,
+        model,
+        sghmc,
+        n_adapts + n_samples;
+        n_adapts=n_adapts,
+        initial_params=θ_init,
+        progress=false,
+        verbose=false,
+    )
+
+    # Transform back to original space.
+    # NOTE: We're not correcting for the `logabsdetjac` here, but since
+    # we're only interested in the mean it doesn't matter.
+    for t in samples_sghmc
+        t.z.θ .= invlink_gdemo(t.z.θ)
+    end
+    m_est_sghmc = mean(samples_sghmc) do t
+        t.z.θ
+    end
+
+    @test m_est_sghmc ≈ [49 / 24, 7 / 6] atol = RNDATOL
+
     samples_custom = AbstractMCMC.sample(
         rng,
         model,
