Skip to content
Closed
36 changes: 34 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,47 @@ This modularity means that different HMC variants can be easily constructed by c
- Diagonal metric: `DiagEuclideanMetric(dim)`
- Dense metric: `DenseEuclideanMetric(dim)`

where `dim` is the dimensionality of the sampling space.
where `dim` is the dimension of the sampling space.

Furthermore, there is experimental support for Riemannian Manifold HMC (RMHMC) with position-dependent metrics. Two metric types are provided:

#### `RiemannianMetric`

For user-provided **positive-definite** metrics (e.g., Fisher information matrix):

```julia
RiemannianMetric(dim, calc_G, calc_∂G∂θ)
```

- `dim`: the dimension of the sampling space (as a tuple, e.g., `(d,)`),
- `calc_G`: a function `θ -> G` returning the positive-definite metric matrix at position `θ`,
- `calc_∂G∂θ`: a function `θ -> ∂G∂θ` returning the Jacobian of `calc_G`, where `∂G∂θ[i,j,k]` is `∂G[i,j]/∂θ[k]`.

#### `SoftAbsRiemannianMetric`

For Hessian-based metrics with SoftAbs regularization, which transforms a symmetric (but not necessarily positive-definite) matrix into a positive-definite one:

```julia
SoftAbsRiemannianMetric(dim, calc_H, calc_∂H∂θ; α=1000.0)
```

- `dim`: the dimension of the sampling space (as a tuple, e.g., `(d,)`),
- `calc_H`: a function `θ -> H` returning the Hessian of the negative log density at position `θ`,
- `calc_∂H∂θ`: a function `θ -> ∂H∂θ` returning the Jacobian of `calc_H`,
- `α`: the SoftAbs regularization parameter (higher values give sharper regularization).

The SoftAbs transformation converts `H` to `G = Q * Diagonal(softabs.(λ, α)) * Q'` where `(λ, Q)` is the eigendecomposition and `softabs(λ, α) = λ * coth(α * λ)`.

Both Riemannian metrics require the `GeneralizedLeapfrog` integrator, which uses implicit fixed-point iterations to handle the position-dependent metric.

### [Integrator (`integrator`)](@id integrator)

- Ordinary leapfrog integrator: `Leapfrog(ϵ)`
- Jittered leapfrog integrator with jitter rate `n`: `JitteredLeapfrog(ϵ, n)`
- Tempered leapfrog integrator with tempering rate `a`: `TemperedLeapfrog(ϵ, a)`
- Generalized leapfrog integrator for Riemannian metrics: `GeneralizedLeapfrog(ϵ, n)`

where `ϵ` is the step size of leapfrog integration.
where `ϵ` is the step size of leapfrog integration and `n` is the number of fixed-point iterations (for `GeneralizedLeapfrog`). The `GeneralizedLeapfrog` integrator is required for Riemannian metrics and uses implicit fixed-point iterations to handle position-dependent metrics.

### [Kernel (`kernel`)](@id kernel)

Expand Down
31 changes: 28 additions & 3 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@ module AdvancedHMC

using Statistics: mean, var, middle
using LinearAlgebra:
Symmetric, UpperTriangular, mul!, ldiv!, dot, I, diag, cholesky, UniformScaling
Symmetric,
UpperTriangular,
Diagonal,
mul!,
ldiv!,
dot,
I,
diag,
cholesky,
UniformScaling,
logdet,
tr,
eigen,
diagm
using StatsFuns: logaddexp, logsumexp, loghalf
using Random: Random, AbstractRNG
using ProgressMeter: ProgressMeter
Expand Down Expand Up @@ -40,7 +53,7 @@ struct GaussianKinetic <: AbstractKinetic end
export GaussianKinetic

include("metric.jl")
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric
export UnitEuclideanMetric, DiagEuclideanMetric, DenseEuclideanMetric, DenseRiemannianMetric

include("hamiltonian.jl")
export Hamiltonian
Expand All @@ -50,6 +63,13 @@ export Leapfrog, JitteredLeapfrog, TemperedLeapfrog
include("riemannian/integrator.jl")
export GeneralizedLeapfrog

include("riemannian/metric.jl")
export RiemannianMetric, SoftAbsRiemannianMetric
# Deprecated exports (for backward compatibility)
export IdentityMap, SoftAbsMap, DenseRiemannianMetric

include("riemannian/hamiltonian.jl")

include("trajectory.jl")
export Trajectory,
HMCKernel,
Expand All @@ -72,7 +92,12 @@ export find_good_eps
include("adaptation/Adaptation.jl")
using .Adaptation
import .Adaptation:
StepSizeAdaptor, MassMatrixAdaptor, StanHMCAdaptor, NesterovDualAveraging, NoAdaptation, PositionOrPhasePoint
StepSizeAdaptor,
MassMatrixAdaptor,
StanHMCAdaptor,
NesterovDualAveraging,
NoAdaptation,
PositionOrPhasePoint

# Helpers for initializing adaptors via AHMC structs

Expand Down
Loading
Loading