2 changes: 1 addition & 1 deletion src/AdvancedHMC.jl
@@ -125,7 +125,7 @@ include("sampler.jl")
export sample

include("constructors.jl")
export HMCSampler, HMC, NUTS, HMCDA
export HMCSampler, HMC, NUTS, HMCDA, SGHMC

include("abstractmcmc.jl")

122 changes: 122 additions & 0 deletions src/abstractmcmc.jl
@@ -205,6 +205,120 @@ function AbstractMCMC.step(
return Transition(t.z, tstat), newstate
end

struct SGHMCState{
TTrans<:Transition,
TMetric<:AbstractMetric,
TKernel<:AbstractMCMCKernel,
TAdapt<:Adaptation.AbstractAdaptor,
T<:AbstractVector{<:Real},
}
"Index of current iteration."
i::Int
"Current [`Transition`](@ref)."
transition::TTrans
"Current [`AbstractMetric`](@ref), possibly adapted."
metric::TMetric
"Current [`AbstractMCMCKernel`](@ref)."
κ::TKernel
"Current [`AbstractAdaptor`](@ref)."
adaptor::TAdapt
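"Current velocity."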
velocity::T
end
getadaptor(state::SGHMCState) = state.adaptor
getmetric(state::SGHMCState) = state.metric
getintegrator(state::SGHMCState) = state.κ.τ.integrator

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
spl::SGHMC;
initial_params=nothing,
kwargs...,
)
# Unpack model
logdensity = model.logdensity

# Define metric
metric = make_metric(spl, logdensity)

# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model)

# Compute initial sample and state.
initial_params = make_initial_params(rng, spl, logdensity, initial_params)
ϵ = make_step_size(rng, spl, hamiltonian, initial_params)
integrator = make_integrator(spl, ϵ)

# Make kernel
κ = make_kernel(spl, integrator)

# Make adaptor
adaptor = make_adaptor(spl, metric, integrator)

# Get an initial sample.
h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params)

state = SGHMCState(0, t, metric, κ, adaptor, initial_params)

return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

function AbstractMCMC.step(
rng::AbstractRNG,
model::AbstractMCMC.LogDensityModel,
spl::SGHMC,
state::SGHMCState;
n_adapts::Int=0,
kwargs...,
)
if haskey(kwargs, :nadapts)
throw(
ArgumentError(
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
),
)
end

i = state.i + 1
t_old = state.transition
adaptor = state.adaptor
κ = state.κ
metric = state.metric

# Reconstruct hamiltonian.
h = Hamiltonian(metric, model)

# Compute gradient of log density.
logdensity_and_gradient = Base.Fix1(
LogDensityProblems.logdensity_and_gradient, model.logdensity
)
θ = copy(t_old.z.θ)
grad = last(logdensity_and_gradient(θ))

# Update latent variables and velocity according to
# equation (15) of Chen et al. (2014)
v = state.velocity
η = spl.learning_rate
α = spl.momentum_decay
newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))
θ .+= newv

# Make new transition.
z = phasepoint(h, θ, newv)
t = transition(rng, h, κ, z)

# Adapt h and spl.
tstat = stat(t)
h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate)
tstat = merge(tstat, (is_adapt=isadapted,))

# Compute next sample and state.
sample = Transition(t.z, tstat)
newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv)

return sample, newstate
end
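
For intuition: the update above is equation (15) of Chen et al. (2014), v ← (1 − α)v + η∇log p(θ) + N(0, 2ηα), followed by θ ← θ + v. Below is a minimal standalone sketch of that update on a toy standard-normal target; `sghmc_sketch` and all values are illustrative, not part of the package.

using Random

# Toy target: standard normal, so ∇ log p(θ) = -θ.
grad_logp(θ) = -θ

function sghmc_sketch(rng, θ0; η=0.01, α=0.1, n=10_000)
    θ, v = copy(θ0), zero(θ0)
    samples = Vector{typeof(θ0)}(undef, n)
    for i in 1:n
        # Velocity update with friction (α) and injected noise, then a position step.
        v = (1 - α) .* v .+ η .* grad_logp(θ) .+ sqrt(2 * η * α) .* randn(rng, length(θ))
        θ = θ .+ v
        samples[i] = copy(θ)
    end
    return samples
end

rng = Random.MersenneTwister(42)
samples = sghmc_sketch(rng, zeros(2))
# The sample mean should settle near the target mean of zero.
mean_est = sum(samples) / length(samples)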

################
### Callback ###
################
@@ -392,6 +506,10 @@ function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractIntegrator)
return NoAdaptation()
end

function make_adaptor(spl::SGHMC, metric::AbstractMetric, integrator::AbstractIntegrator)
return NoAdaptation()
end

function make_adaptor(
spl::HMCSampler, metric::AbstractMetric, integrator::AbstractIntegrator
)
@@ -417,3 +535,7 @@
function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator)
return spl.κ
end

function make_kernel(spl::SGHMC, integrator::AbstractIntegrator)
return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(spl.n_leapfrog)))
end
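
Here `EndPointTS` is the trajectory sampler that proposes the end point of the trajectory, and `FixedNSteps(spl.n_leapfrog)` is the termination criterion fixing the trajectory length, so each iteration runs exactly `n_leapfrog` leapfrog steps. A sketch of the same construction with explicit, illustrative values:

using AdvancedHMC

# The kernel make_kernel builds for SGHMC, spelled out by hand.
integrator = Leapfrog(0.05)  # illustrative step size, not a default
κ = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(100)))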
45 changes: 45 additions & 0 deletions src/constructors.jl
@@ -163,3 +163,48 @@ function HMCDA(δ, λ; integrator=:leapfrog, metric=:diagonal)
end

sampler_eltype(::HMCDA{T}) where {T} = T

#############
### SGHMC ###
#############
"""
SGHMC(learning_rate::Real, momentum_decay::Real, n_leapfrog::Int; integrator = :leapfrog, metric = :diagonal)

Stochastic Gradient Hamiltonian Monte Carlo (SGHMC) sampler.

# Fields

$(FIELDS)

# Notes

For more information, please see the following paper ([arXiv link](https://arxiv.org/abs/1402.4102)):

- Chen, Tianqi, Emily Fox, and Carlos Guestrin. "Stochastic Gradient Hamiltonian Monte Carlo." International Conference on Machine Learning. PMLR, 2014.
"""
struct SGHMC{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <:
AbstractHMCSampler
"Learning rate for the gradient descent."
learning_rate::T
"Momentum decay rate."
momentum_decay::T
"Number of leapfrog steps."
n_leapfrog::Int
"Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)"
integrator::I
"Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption."
metric::M
end

function SGHMC(
learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal
)
T = determine_sampler_eltype(
learning_rate, momentum_decay, n_leapfrog, integrator, metric
)
return SGHMC(T(learning_rate), T(momentum_decay), n_leapfrog, integrator, metric)
end

sampler_eltype(::SGHMC{T}) where {T} = T
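
Putting the constructor and the `AbstractMCMC` interface together, a usage sketch; the `Gauss` target and all values here are illustrative, and any type implementing the first-order LogDensityProblems interface works:

using AbstractMCMC, AdvancedHMC, LogDensityProblems, Random

# Hypothetical target: a 2-D standard normal.
struct Gauss end
LogDensityProblems.logdensity(::Gauss, θ) = -sum(abs2, θ) / 2
LogDensityProblems.dimension(::Gauss) = 2
function LogDensityProblems.capabilities(::Type{Gauss})
    return LogDensityProblems.LogDensityOrder{1}()
end
LogDensityProblems.logdensity_and_gradient(::Gauss, θ) = (-sum(abs2, θ) / 2, -θ)

rng = Random.MersenneTwister(1)
spl = SGHMC(0.01, 0.1, 100)  # learning rate, momentum decay, n_leapfrog
model = AbstractMCMC.LogDensityModel(Gauss())
samples = AbstractMCMC.sample(rng, model, spl, 1_000; progress=false)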
24 changes: 24 additions & 0 deletions test/abstractmcmc.jl
@@ -10,6 +10,7 @@ using Statistics: mean
nuts = NUTS(0.8)
hmc = HMC(100; integrator=Leapfrog(0.05))
hmcda = HMCDA(0.8, 0.1)
sghmc = SGHMC(0.01, 0.1, 100)

integrator = Leapfrog(1e-3)
κ = AdvancedHMC.make_kernel(nuts, integrator)
@@ -111,6 +112,29 @@

@test m_est_hmc ≈ [49 / 24, 7 / 6] atol = RNDATOL

samples_sghmc = AbstractMCMC.sample(
rng,
model,
sghmc,
n_adapts + n_samples;
n_adapts=n_adapts,
initial_params=θ_init,
progress=false,
verbose=false,
)

# Transform back to original space.
# NOTE: We're not correcting for the `logabsdetjac` here; since we're
# only interested in the mean, it doesn't matter.
for t in samples_sghmc
t.z.θ .= invlink_gdemo(t.z.θ)
end
m_est_sghmc = mean(samples_sghmc) do t
t.z.θ
end

@test m_est_sghmc ≈ [49 / 24, 7 / 6] atol = RNDATOL

samples_custom = AbstractMCMC.sample(
rng,
model,