From a279425b988def4585f2966dd0b04a8ec051289c Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 23 Apr 2025 19:11:42 +0800 Subject: [PATCH 1/8] Add Stochastic Gradient HMC --- src/AdvancedHMC.jl | 2 +- src/abstractmcmc.jl | 107 +++++++++++++++++++++++++++++++++++++++++++ src/constructors.jl | 45 ++++++++++++++++++ test/abstractmcmc.jl | 24 ++++++++++ 4 files changed, 177 insertions(+), 1 deletion(-) diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 51cf5f57..786bb63d 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -125,7 +125,7 @@ include("sampler.jl") export sample include("constructors.jl") -export HMCSampler, HMC, NUTS, HMCDA +export HMCSampler, HMC, NUTS, HMCDA, SGHMC include("abstractmcmc.jl") diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 1ae71a72..a94a9f7e 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -205,6 +205,105 @@ function AbstractMCMC.step( return Transition(t.z, tstat), newstate end +struct SGHMCState{T<:AbstractVector{<:Real}} + "Index of current iteration." + i + "Current [`Transition`](@ref)." + transition + "Current [`AbstractMetric`](@ref), possibly adapted." + metric + "Current [`AbstractMCMCKernel`](@ref)." + κ + "Current [`AbstractAdaptor`](@ref)." + adaptor + velocity::T +end +getadaptor(state::SGHMCState) = state.adaptor +getmetric(state::SGHMCState) = state.metric +getintegrator(state::SGHMCState) = state.κ.τ.integrator + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.LogDensityModel, + spl::SGHMC; + initial_params=nothing, + kwargs..., +) + # Unpack model + logdensity = model.logdensity + + # Define metric + metric = make_metric(spl, logdensity) + + # Construct the hamiltonian using the initial metric + hamiltonian = Hamiltonian(metric, model) + + # Compute initial sample and state. 
+ initial_params = make_initial_params(rng, spl, logdensity, initial_params) + ϵ = make_step_size(rng, spl, hamiltonian, initial_params) + integrator = make_integrator(spl, ϵ) + + # Make kernel + κ = make_kernel(spl, integrator) + + # Make adaptor + adaptor = make_adaptor(spl, metric, integrator) + + # Get an initial sample. + h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params) + + state = SGHMCState(0, t, metric, κ, adaptor, initial_params, zero(initial_params)) + + return AbstractMCMC.step(rng, model, spl, state; kwargs...) +end + +function AbstractMCMC.step( + rng::AbstractRNG, + model::AbstractMCMC.LogDensityModel, + spl::SGHMC, + state::SGHMCState; + n_adapts::Int=0, + kwargs..., +) + i = state.i + 1 + t_old = state.transition + adaptor = state.adaptor + κ = state.κ + metric = state.metric + + # Reconstruct hamiltonian. + h = Hamiltonian(metric, model) + + # Compute gradient of log density. + logdensity_and_gradient = Base.Fix1( + LogDensityProblems.logdensity_and_gradient, model.logdensity + ) + θ = t_old.z.θ + grad = last(logdensity_and_gradient(θ)) + + # Update latent variables and velocity according to + # equation (15) of Chen et al. (2014) + v = state.velocity + θ .+= v + η = spl.learning_rate + α = spl.momentum_decay + newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) + + # Adapt h and spl. + tstat = stat(t) + h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate) + tstat = merge(tstat, (is_adapt=isadapted,)) + + # Make new transition. + t = transition(rng, h, κ, t_old.z) + + # Compute next sample and state. 
+ sample = Transition(t.z, tstat) + newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv) + + return sample, newstate +end + ################ ### Callback ### ################ @@ -392,6 +491,10 @@ function make_adaptor(spl::HMC, metric::AbstractMetric, integrator::AbstractInte return NoAdaptation() end +function make_adaptor(spl::SGHMC, metric::AbstractMetric, integrator::AbstractIntegrator) + return NoAdaptation() +end + function make_adaptor( spl::HMCSampler, metric::AbstractMetric, integrator::AbstractIntegrator ) @@ -417,3 +520,7 @@ end function make_kernel(spl::HMCSampler, integrator::AbstractIntegrator) return spl.κ end + +function make_kernel(spl::SGHMC, integrator::AbstractIntegrator) + return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(spl.n_leapfrog))) +end diff --git a/src/constructors.jl b/src/constructors.jl index 4d60fdef..69b903ba 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -163,3 +163,48 @@ function HMCDA(δ, λ; integrator=:leapfrog, metric=:diagonal) end sampler_eltype(::HMCDA{T}) where {T} = T + +########### Static Hamiltonian Monte Carlo ########### + +############# +### SGHMC ### +############# +""" + SGHMC(learning_rate::Real, momentum_decay::Real, n_leapfrog::Int; integrator = :leapfrog, metric = :diagonal) + +Stochastic Gradient Hamiltonian Monte Carlo sampler + +# Fields + +$(FIELDS) + +# Notes + +For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1402.4102)): + +- Chen, Tianqi, Emily Fox, and Carlos Guestrin. "Stochastic gradient hamiltonian monte carlo." International conference on machine learning. PMLR, 2014. +""" +struct SGHMC{T<:Real,I<:Union{Symbol,AbstractIntegrator},M<:Union{Symbol,AbstractMetric}} <: + AbstractHMCSampler + "Learning rate for the gradient descent." + learning_rate::T + "Momentum decay rate." + momentum_decay::T + "Number of leapfrog steps." 
+ n_leapfrog::Int + "Choice of integrator, specified either using a `Symbol` or [`AbstractIntegrator`](@ref)" + integrator::I + "Choice of initial metric; `Symbol` means it is automatically initialised. The metric type will be preserved during automatic initialisation and adaption." + metric::M +end + +function SGHMC( + learning_rate, momentum_decay, n_leapfrog; integrator=:leapfrog, metric=:diagonal +) + T = determine_sampler_eltype( + learning_rate, momentum_decay, n_leapfrog, integrator, metric + ) + return SGHMC(T(learning_rate), T(momentum_decay), n_leapfrog, integrator, metric) +end + +sampler_eltype(::SGHMC{T}) where {T} = T diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index da448f78..50207cb0 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -10,6 +10,7 @@ using Statistics: mean nuts = NUTS(0.8) hmc = HMC(100; integrator=Leapfrog(0.05)) hmcda = HMCDA(0.8, 0.1) + sghmc = SGHMC(0.01, 0.1, 100) integrator = Leapfrog(1e-3) κ = AdvancedHMC.make_kernel(nuts, integrator) @@ -111,6 +112,29 @@ using Statistics: mean @test m_est_hmc ≈ [49 / 24, 7 / 6] atol = RNDATOL + samples_sghmc = AbstractMCMC.sample( + rng, + model, + sghmc, + n_adapts + n_samples; + n_adapts=n_adapts, + initial_params=θ_init, + progress=false, + verbose=false, + ) + + # Transform back to original space. + # NOTE: We're not correcting for the `logabsdetjac` here, but since + # we're only interested in the mean it doesn't matter. 
+ for t in samples_sghmc + t.z.θ .= invlink_gdemo(t.z.θ) + end + m_est_sghmc = mean(samples_sghmc) do t + t.z.θ + end + + @test m_est_sghmc ≈ [49 / 24, 7 / 6] atol = RNDATOL + samples_custom = AbstractMCMC.sample( rng, model, From 5170e8559c397168c2a39d26be1d9c6bfe18beb8 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Sun, 18 May 2025 17:35:23 +0800 Subject: [PATCH 2/8] Ensure type stability --- src/abstractmcmc.jl | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index a94a9f7e..61befad1 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -205,17 +205,23 @@ function AbstractMCMC.step( return Transition(t.z, tstat), newstate end -struct SGHMCState{T<:AbstractVector{<:Real}} +struct SGHMCState{ + TTrans<:Transition, + TMetric<:AbstractMetric, + TKernel<:AbstractMCMCKernel, + TAdapt<:Adaptation.AbstractAdaptor, + T<:AbstractVector{<:Real}, +} "Index of current iteration." - i + i::Int "Current [`Transition`](@ref)." - transition + transition::TTrans "Current [`AbstractMetric`](@ref), possibly adapted." - metric + metric::TMetric "Current [`AbstractMCMCKernel`](@ref)." - κ + κ::TKernel "Current [`AbstractAdaptor`](@ref)." - adaptor + adaptor::TAdapt velocity::T end getadaptor(state::SGHMCState) = state.adaptor @@ -252,7 +258,7 @@ function AbstractMCMC.step( # Get an initial sample. h, t = AdvancedHMC.sample_init(rng, hamiltonian, initial_params) - state = SGHMCState(0, t, metric, κ, adaptor, initial_params, zero(initial_params)) + state = SGHMCState(0, t, metric, κ, adaptor, initial_params) return AbstractMCMC.step(rng, model, spl, state; kwargs...) end @@ -265,6 +271,14 @@ function AbstractMCMC.step( n_adapts::Int=0, kwargs..., ) + if haskey(kwargs, :nadapts) + throw( + ArgumentError( + "keyword argument `nadapts` is unsupported. 
Please use `n_adapts` to specify the number of adaptation steps.", + ), + ) + end + i = state.i + 1 t_old = state.transition adaptor = state.adaptor @@ -289,14 +303,14 @@ function AbstractMCMC.step( α = spl.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) + # Make new transition. + t = transition(rng, h, κ, t_old.z) + # Adapt h and spl. tstat = stat(t) h, κ, isadapted = adapt!(h, κ, adaptor, i, n_adapts, θ, tstat.acceptance_rate) tstat = merge(tstat, (is_adapt=isadapted,)) - # Make new transition. - t = transition(rng, h, κ, t_old.z) - # Compute next sample and state. sample = Transition(t.z, tstat) newstate = SGHMCState(i, t, h.metric, κ, adaptor, newv) From 5619cc46847f5794cdc8cb901b408a949401bce6 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 22 May 2025 12:26:33 +0800 Subject: [PATCH 3/8] Fix inplace mistake --- src/abstractmcmc.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 61befad1..b21b351f 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -298,10 +298,10 @@ function AbstractMCMC.step( # Update latent variables and velocity according to # equation (15) of Chen et al. (2014) v = state.velocity - θ .+= v η = spl.learning_rate α = spl.momentum_decay newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v)) + θ .+= newv # Make new transition. 
t = transition(rng, h, κ, t_old.z) From 715aefa2f9fa774fc68a54d3582d4dbe3e00ff1e Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 22 May 2025 12:57:24 +0800 Subject: [PATCH 4/8] Fix inplace mistake for theta --- src/abstractmcmc.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index b21b351f..a59e8002 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -292,7 +292,7 @@ function AbstractMCMC.step( logdensity_and_gradient = Base.Fix1( LogDensityProblems.logdensity_and_gradient, model.logdensity ) - θ = t_old.z.θ + θ = copy(t_old.z.θ) grad = last(logdensity_and_gradient(θ)) # Update latent variables and velocity according to @@ -304,7 +304,8 @@ function AbstractMCMC.step( θ .+= newv # Make new transition. - t = transition(rng, h, κ, t_old.z) + z = phasepoint(h, θ, v) + t = transition(rng, h, κ, z) # Adapt h and spl. tstat = stat(t) From d62c857eabde4de6317358310326e3eefea8cbaf Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 27 May 2025 00:32:42 +0800 Subject: [PATCH 5/8] Bump major version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 599dae3c..77e85396 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.8.0" +version = "0.9.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 302e569625e2386e6d281144f3f91b6e75d60956 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 27 May 2025 00:37:13 +0800 Subject: [PATCH 6/8] Update HISTORY --- HISTORY.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 83d8882e..605f8b15 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # AdvancedHMC Changelog +## 0.9.0 + + - Stochastic gradient based methods are supported in AdvancedHMC.jl, 
+ please note such methods will be removed in Turing.jl version 0.39.0. + ## 0.8.0 - To make an MCMC transtion from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`). From a90bdf9d1fc0b35ddfce48323688b2e2e99c5a11 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 27 May 2025 00:50:43 +0800 Subject: [PATCH 7/8] Bump compat in docs --- docs/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index c48bd544..ba4848b5 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" [compat] -AdvancedHMC = "0.8" +AdvancedHMC = "0.9" Documenter = "1" -DocumenterCitations = "1" \ No newline at end of file +DocumenterCitations = "1" From 889747cdf54b9dab9cd368ce168efba4aa233124 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Tue, 27 May 2025 01:14:49 +0800 Subject: [PATCH 8/8] Better HISTORY --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 605f8b15..0de72886 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,7 @@ ## 0.9.0 - - Stochastic gradient based methods are supported in AdvancedHMC.jl, please note such methods will be removed in Turing.jl version 0.39.0. + - Stochastic gradient based methods `SGHMC` and `SGLD` are supported in AdvancedHMC.jl. Please note that Turing.jl provides methods with the same names, so when using the two packages together, please specify which package's method you mean (e.g. `AdvancedHMC.SGHMC`). ## 0.8.0