From f69c3f0f444f65cb0a381d0c9109af1d51f9db82 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 7 May 2025 18:18:09 +0800 Subject: [PATCH 1/3] Add vectorized HMC docs --- docs/make.jl | 1 + docs/src/get_started.md | 2 +- docs/src/vectorized.md | 28 ++++++++++++++++++ src/sampler.jl | 6 ++-- src/trajectory.jl | 7 +++-- test/demo.jl | 10 ++++--- test/integrator.jl | 5 ++-- test/trajectory.jl | 65 +++++++++++++++++------------------------ 8 files changed, 73 insertions(+), 51 deletions(-) create mode 100644 docs/src/vectorized.md diff --git a/docs/make.jl b/docs/make.jl index 62a1f624..b7b8845a 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -16,6 +16,7 @@ makedocs(; pages=[ "AdvancedHMC.jl" => "index.md", "Get Started" => "get_started.md", + "Vectorized HMC" => "vectorized.md", "Automatic Differentiation Backends" => "autodiff.md", "Detailed API" => "api.md", "Interfaces" => "interfaces.md", diff --git a/docs/src/get_started.md b/docs/src/get_started.md index 6629aa52..a645fc20 100644 --- a/docs/src/get_started.md +++ b/docs/src/get_started.md @@ -1,4 +1,4 @@ -# Sampling from a multivariate Gaussian using NUTS +# [Sampling from a multivariate Gaussian using NUTS](@id get_started) In this section, we demonstrate a minimal example of sampling from a multivariate Gaussian (10-dimensional) using the No U-Turn Sampler (NUTS). Below we describe the major components of the Hamiltonian system which are essential to sample using this approach: diff --git a/docs/src/vectorized.md b/docs/src/vectorized.md new file mode 100644 index 00000000..9cabed17 --- /dev/null +++ b/docs/src/vectorized.md @@ -0,0 +1,28 @@ +# Vectorized HMC Sampling + +In this section, we explain how to easily employ vectorized Hamiltonian Monte Carlo with AdvancedHMC.jl. Let's continue with the previous example in [getting-started](@ref get_started), we want to sample a multivariate Gaussian (10-dimensional) with multiple chains, we can simply utilize the vectorized initial parameters, leapfrod integrator, and metric. Here, the vectorized log density problems come from [MCMCLogDensityProblems.jl](https://github.com/chalk-lab/MCMCLogDensityProblems.jl) which is a library of common vectorized log density target distributions. + +```julia +using AdvancedHMC +using MCMCLogDensityProblems + +D = 5 +target = HighDimGaussian(D) +ℓπ(x) = logpdf(target, x) +∂ℓπ∂θ(x) = logpdf_grad(target, x) + +n_chains = 5 +θ_init = rand(D, n_chains) +ϵ = 0.1 +lfi = Leapfrog(fill(ϵ, n_chains)) +n_steps = 10 +n_samples = 20_000 +metric = DiagEuclideanMetric((D, n_chains)) +τ = Trajectory{EndPointTS}(lfi, FixedNSteps(n_steps)) +h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) +samples, stats = sample(h, HMCKernel(τ), θ_init, n_samples; verbose=false) +``` + +!!! note + + `NUTS` sampler doesn't support vectorized sampling for now. diff --git a/src/sampler.jl b/src/sampler.jl index c0a42681..e0138819 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -117,7 +117,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=pm_next!, + (pm_next!)::Function=(pm_next!), ) return sample( Random.default_rng(), @@ -130,7 +130,7 @@ function sample( drop_warmup=drop_warmup, verbose=verbose, progress=progress, - (pm_next!)=pm_next!, + (pm_next!)=(pm_next!), ) end @@ -168,7 +168,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=pm_next!, + (pm_next!)::Function=(pm_next!), ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." # Prepare containers to store sampling results diff --git a/src/trajectory.jl b/src/trajectory.jl index a7680760..aa8c90ca 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -133,8 +133,9 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -SliceTS(rng::AbstractRNG, z0::PhasePoint) = +function SliceTS(rng::AbstractRNG, z0::PhasePoint) SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) +end """ $(TYPEDEF) @@ -278,7 +279,7 @@ function transition( hamiltonian_energy=H, hamiltonian_energy_error=H - H0, # check numerical error in proposed phase point. - numerical_error=!all(isfinite, H′), + numerical_error=(!all(isfinite, H′)), ), stat(τ.integrator), ) @@ -717,7 +718,7 @@ function transition( ( n_steps=tree.nα, is_accept=true, - acceptance_rate=tree.sum_α / tree.nα, + acceptance_rate=(tree.sum_α / tree.nα), log_density=zcand.ℓπ.value, hamiltonian_energy=H, hamiltonian_energy_error=H - H0, diff --git a/test/demo.jl b/test/demo.jl index 98315daa..c9010a7f 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -10,8 +10,9 @@ using LinearAlgebra, ADTypes LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ) LogDensityProblems.dimension(p::DemoProblem) = p.dim - LogDensityProblems.capabilities(::Type{DemoProblem}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{ + 0 + }() # Choose parameter dimensionality and initial parameter value D = 10 @@ -66,8 +67,9 @@ end return -((1 - p.μ) / p.σ)^2 end LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2 - LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{ + 0 + }() ℓπ = DemoProblemComponentArrays() diff --git a/test/integrator.jl b/test/integrator.jl index b9eb1407..f5a3dbea 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -112,8 +112,9 @@ using Statistics: mean LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2 LogDensityProblems.dimension(d::NegU) = d.dim - LogDensityProblems.capabilities(::Type{NegU}) = - LogDensityProblems.LogDensityOrder{0}() + LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{ + 0 + }() negU = NegU(1) diff --git a/test/trajectory.jl b/test/trajectory.jl index 403fd446..4bf0ac4d 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -257,46 +257,35 @@ end traj_r = hcat(map(z -> z.r, traj_z)...) rho = cumsum(traj_r; dims=2) - ts_hand_isturn_fwd = - hand_isturn.( - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) - ts_ahmc_isturn_fwd = - ahmc_isturn.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_fwd = hand_isturn.( + Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) + ) + ts_ahmc_isturn_fwd = ahmc_isturn.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_hand_isturn_generalised_fwd = - hand_isturn_generalised.( - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) - ts_ahmc_isturn_generalised_fwd = - ahmc_isturn_generalised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_generalised_fwd = hand_isturn_generalised.( + Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) + ) + ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_ahmc_isturn_strictgeneralised_fwd = - ahmc_isturn_strictgeneralised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strictgeneralised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) check_subtree_u_turns.( Ref(h), Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)] From c186b5d071f3b8f9e8f23fcdf9f26f1d92440023 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 7 May 2025 22:42:05 +0800 Subject: [PATCH 2/3] Fix format issue --- src/sampler.jl | 6 ++--- src/trajectory.jl | 7 +++-- test/demo.jl | 10 +++---- test/integrator.jl | 5 ++-- test/trajectory.jl | 65 +++++++++++++++++++++++++++------------------- 5 files changed, 50 insertions(+), 43 deletions(-) diff --git a/src/sampler.jl b/src/sampler.jl index e0138819..c0a42681 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -117,7 +117,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=(pm_next!), + (pm_next!)::Function=pm_next!, ) return sample( Random.default_rng(), @@ -130,7 +130,7 @@ function sample( drop_warmup=drop_warmup, verbose=verbose, progress=progress, - (pm_next!)=(pm_next!), + (pm_next!)=pm_next!, ) end @@ -168,7 +168,7 @@ function sample( drop_warmup=false, verbose::Bool=true, progress::Bool=false, - (pm_next!)::Function=(pm_next!), + (pm_next!)::Function=pm_next!, ) where {T<:AbstractVecOrMat{<:AbstractFloat}} @assert !(drop_warmup && (adaptor isa Adaptation.NoAdaptation)) "Cannot drop warmup samples if there is no adaptation phase." # Prepare containers to store sampling results diff --git a/src/trajectory.jl b/src/trajectory.jl index aa8c90ca..a7680760 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -133,9 +133,8 @@ $(TYPEDEF) Slice sampler for the starting single leaf tree. Slice variable is initialized. """ -function SliceTS(rng::AbstractRNG, z0::PhasePoint) +SliceTS(rng::AbstractRNG, z0::PhasePoint) = SliceTS(z0, neg_energy(z0) - Random.randexp(rng), 1) -end """ $(TYPEDEF) @@ -279,7 +278,7 @@ function transition( hamiltonian_energy=H, hamiltonian_energy_error=H - H0, # check numerical error in proposed phase point. - numerical_error=(!all(isfinite, H′)), + numerical_error=!all(isfinite, H′), ), stat(τ.integrator), ) @@ -718,7 +717,7 @@ function transition( ( n_steps=tree.nα, is_accept=true, - acceptance_rate=(tree.sum_α / tree.nα), + acceptance_rate=tree.sum_α / tree.nα, log_density=zcand.ℓπ.value, hamiltonian_energy=H, hamiltonian_energy_error=H - H0, diff --git a/test/demo.jl b/test/demo.jl index c9010a7f..98315daa 100644 --- a/test/demo.jl +++ b/test/demo.jl @@ -10,9 +10,8 @@ using LinearAlgebra, ADTypes LogDensityProblems.logdensity(p::DemoProblem, θ) = logpdf(MvNormal(zeros(p.dim), I), θ) LogDensityProblems.dimension(p::DemoProblem) = p.dim - LogDensityProblems.capabilities(::Type{DemoProblem}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{DemoProblem}) = + LogDensityProblems.LogDensityOrder{0}() # Choose parameter dimensionality and initial parameter value D = 10 @@ -67,9 +66,8 @@ end return -((1 - p.μ) / p.σ)^2 end LogDensityProblems.dimension(::DemoProblemComponentArrays) = 2 - LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{DemoProblemComponentArrays}) = + LogDensityProblems.LogDensityOrder{0}() ℓπ = DemoProblemComponentArrays() diff --git a/test/integrator.jl b/test/integrator.jl index f5a3dbea..b9eb1407 100644 --- a/test/integrator.jl +++ b/test/integrator.jl @@ -112,9 +112,8 @@ using Statistics: mean LogDensityProblems.logdensity(::NegU, x) = -dot(x, x) / 2 LogDensityProblems.dimension(d::NegU) = d.dim - LogDensityProblems.capabilities(::Type{NegU}) = LogDensityProblems.LogDensityOrder{ - 0 - }() + LogDensityProblems.capabilities(::Type{NegU}) = + LogDensityProblems.LogDensityOrder{0}() negU = NegU(1) diff --git a/test/trajectory.jl b/test/trajectory.jl index 4bf0ac4d..403fd446 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -257,35 +257,46 @@ end traj_r = hcat(map(z -> z.r, traj_z)...) rho = cumsum(traj_r; dims=2) - ts_hand_isturn_fwd = hand_isturn.( - Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) - ) - ts_ahmc_isturn_fwd = ahmc_isturn.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_fwd = + hand_isturn.( + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) + ts_ahmc_isturn_fwd = + ahmc_isturn.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_hand_isturn_generalised_fwd = hand_isturn_generalised.( - Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)], Ref(1) - ) - ts_ahmc_isturn_generalised_fwd = ahmc_isturn_generalised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_hand_isturn_generalised_fwd = + hand_isturn_generalised.( + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) + ts_ahmc_isturn_generalised_fwd = + ahmc_isturn_generalised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) - ts_ahmc_isturn_strictgeneralised_fwd = ahmc_isturn_strictgeneralised.( - Ref(h), - Ref(traj_z[1]), - traj_z, - [rho[:, i] for i in 1:length(traj_z)], - Ref(1), - ) + ts_ahmc_isturn_strictgeneralised_fwd = + ahmc_isturn_strictgeneralised.( + Ref(h), + Ref(traj_z[1]), + traj_z, + [rho[:, i] for i in 1:length(traj_z)], + Ref(1), + ) check_subtree_u_turns.( Ref(h), Ref(traj_z[1]), traj_z, [rho[:, i] for i in 1:length(traj_z)] From 0563bed01cde2b6890c89ca757858e1b60d25eda Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Fri, 9 May 2025 13:37:28 +0800 Subject: [PATCH 3/3] Add more details --- docs/src/vectorized.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/vectorized.md b/docs/src/vectorized.md index 9cabed17..d883d730 100644 --- a/docs/src/vectorized.md +++ b/docs/src/vectorized.md @@ -1,12 +1,12 @@ # Vectorized HMC Sampling -In this section, we explain how to easily employ vectorized Hamiltonian Monte Carlo with AdvancedHMC.jl. Let's continue with the previous example in [getting-started](@ref get_started), we want to sample a multivariate Gaussian (10-dimensional) with multiple chains, we can simply utilize the vectorized initial parameters, leapfrod integrator, and metric. Here, the vectorized log density problems come from [MCMCLogDensityProblems.jl](https://github.com/chalk-lab/MCMCLogDensityProblems.jl) which is a library of common vectorized log density target distributions. +In this section, we explain how to easily employ vectorized Hamiltonian Monte Carlo with AdvancedHMC.jl. Let's continue with the previous example in [getting-started](@ref get_started), we want to sample a multivariate Gaussian (10-dimensional) with multiple chains, we can simply specify the number of chains in initial parameters, leapfrog integrator, and metric to tell AdvanceHMC.jl how many chains we want to sample. Here, the vectorized multivariate Gaussian log density problem come from [MCMCLogDensityProblems.jl](https://github.com/chalk-lab/MCMCLogDensityProblems.jl) which is a library of common log density target distributions designed for vectorized sampling. ```julia using AdvancedHMC using MCMCLogDensityProblems -D = 5 +D = 10 target = HighDimGaussian(D) ℓπ(x) = logpdf(target, x) ∂ℓπ∂θ(x) = logpdf_grad(target, x) @@ -25,4 +25,4 @@ samples, stats = sample(h, HMCKernel(τ), θ_init, n_samples; verbose=false) !!! note - `NUTS` sampler doesn't support vectorized sampling for now. + Vectorized sampling only support static HMC, which means samplers like `NUTS` should not be used for vectorized sampling for now.