
Unified Riemannian Metric API #485

Open

THargreaves wants to merge 23 commits into main from th/unified-rhmc

Conversation

@THargreaves

Overview

Insert relevant XKCD — many apologies guys but I felt like another RHMC PR was needed.

I've had a good look through Qingyu's, Niko's, and Jamie's PRs and found them to be a great start, with some useful ideas for tidying up the RHMC interface.

I couldn't help but feel, however, that they were all making small (though useful) tweaks without fundamentally addressing the poor interface design of the current RHMC implementation.

This PR builds on Niko's PR (which in turn builds on Qingyu's) and pulls in some ideas from Jamie's, while fundamentally reshaping the RHMC interface so that it is philosophically sound.

The high-level takeaway of these changes is that RHMC now works with both the SoftAbs eigendecomposition of G and the Cholesky decomposition of a guaranteed-PD G, while sharing 90% of the code paths between the two.

The key insight is that the gradient computation can be unified by introducing two dispatch functions (logdet_grad_matrix and kinetic_grad_matrix) that return matrices encoding the appropriate chain rule for each metric type.

We also introduce the notion of a pre-metric: for a PD metric this is just G, while for SoftAbs it is H. This abstraction allows the SoftAbs Hamiltonian derivatives to be computed with optimal caching, since we can pull the derivative part of G out into logdet_grad_matrix, leaving just H behind as in Betancourt's paper.

I've added tests for energy conservation and verifying that the samplers run but it would be nice to add some validity tests too. Could this be something @J-Price-3 could look at?

Please feel free to tear this PR apart; it's just a draft and I'm open to changes.

New Types

RiemannianMetric

For user-provided positive-definite metrics (e.g., Fisher information matrix):

calc_G = θ -> fisher_information(θ)
calc_∂G∂θ = θ -> autodiff_jacobian(calc_G, θ)
metric = RiemannianMetric((d,), calc_G, calc_∂G∂θ)
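To make the two user-supplied callables concrete, here is a toy runnable example (not from the PR): a simple always-PD metric G(θ) = I + θθᵀ with its analytic derivative tensor. The `fisher_information` and `autodiff_jacobian` helpers above are placeholders, and the (d, d, d) array layout here is an assumption about the expected Jacobian shape:

```julia
using LinearAlgebra

# Toy positive-definite metric for illustration: G(θ) = I + θθᵀ (always PD),
# with analytic derivative ∂G/∂θₖ = eₖθᵀ + θeₖᵀ stacked into a 3-d array.
d = 2
calc_G = θ -> Matrix(1.0I, d, d) + θ * θ'
function calc_∂G∂θ(θ)
    dG = zeros(d, d, d)          # dG[:, :, k] holds ∂G/∂θₖ
    for k in 1:d
        e = zeros(d); e[k] = 1.0
        dG[:, :, k] = e * θ' + θ * e'
    end
    return dG
end
```

These two closures are what the `RiemannianMetric((d,), calc_G, calc_∂G∂θ)` call above would receive; in practice the derivative would typically come from an AD tool such as ForwardDiff rather than being hand-derived.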

SoftAbsRiemannianMetric

For Hessian-based metrics with SoftAbs regularization:

calc_H = θ -> hessian_of_neg_log_density(θ)
calc_∂H∂θ = θ -> autodiff_jacobian(calc_H, θ)
metric = SoftAbsRiemannianMetric((d,), calc_H, calc_∂H∂θ, α=20.0)

SoftAbsEval

Cached result of evaluating a SoftAbs metric, storing:

  • Q: Eigenvectors
  • softabsλ: Transformed eigenvalues λ * coth(αλ)
  • J: Divided difference matrix for gradient chain rule
  • M_logdet: Precomputed matrix for logdet gradient
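To make the cached fields concrete, here is a hedged sketch of how such an evaluation could be computed. The field names follow the list above, but the helper names, tolerances, and construction details are illustrative assumptions, not the PR's actual code:

```julia
using LinearAlgebra

# Illustrative sketch only; not the PR's actual implementation.
struct SoftAbsEval{T<:Real}
    Q::Matrix{T}         # eigenvectors of the pre-metric H
    softabsλ::Vector{T}  # regularized eigenvalues λ .* coth.(α .* λ)
    J::Matrix{T}         # divided-difference matrix for the chain rule
    M_logdet::Matrix{T}  # Q * Diagonal(diag(J) ./ softabsλ) * Q'
end

# SoftAbs transform and its derivative; the αλ → 0 limit keeps the
# metric positive definite near λ = 0 (softabs(0) = 1/α).
softabs(λ, α) = abs(α * λ) < 1e-8 ? 1 / α : λ * coth(α * λ)
softabs′(λ, α) = abs(α * λ) < 1e-8 ? zero(λ) :
    coth(α * λ) - α * λ / sinh(α * λ)^2

function softabs_eval(H::AbstractMatrix, α::Real)
    λ, Q = eigen(Symmetric(H))
    s = softabs.(λ, α)
    n = length(λ)
    # Daleckii–Krein divided differences; diagonal entries use s′(λᵢ)
    J = [abs(λ[i] - λ[j]) < 1e-10 ? softabs′(λ[i], α) :
         (s[i] - s[j]) / (λ[i] - λ[j]) for i in 1:n, j in 1:n]
    M_logdet = Q * Diagonal(diag(J) ./ s) * Q'
    return SoftAbsEval(Q, s, J, M_logdet)
end
```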

Key Design Decisions

Unified Gradient Computation

The Hamiltonian gradient ∂H/∂θ has three terms:

  1. Potential energy gradient: -∂ℓπ/∂θ
  2. Logdet term: 0.5 * tr(M_logdet * ∂P/∂θ)
  3. Kinetic term: -0.5 * tr(M_kinetic * ∂P/∂θ)

Where P is the "pre-metric" (G for RiemannianMetric, H for SoftAbsRiemannianMetric).

The dispatch functions logdet_grad_matrix(G) and kinetic_grad_matrix(G, r) return M_logdet and M_kinetic respectively, with the J matrix (divided difference formula) absorbed to handle the chain rule through the SoftAbs transformation.
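For the plain positive-definite case the two matrices have closed forms, which a hypothetical sketch (assuming a Cholesky-factored metric; the PR's actual method definitions may differ) makes explicit:

```julia
using LinearAlgebra

# Hypothetical PD-metric methods. With G Cholesky-factored:
#   ∂log|G|/∂θᵢ    =  tr(G⁻¹ ∂G/∂θᵢ)             → M_logdet  = G⁻¹
#   ∂(½rᵀG⁻¹r)/∂θᵢ = -½ tr(G⁻¹rrᵀG⁻¹ ∂G/∂θᵢ)     → M_kinetic = G⁻¹rrᵀG⁻¹
logdet_grad_matrix(G::Cholesky) = Symmetric(inv(G))

function kinetic_grad_matrix(G::Cholesky, r::AbstractVector)
    v = G \ r          # v = G⁻¹r via two triangular solves
    return v * v'      # rank-one M_kinetic; no explicit inverse needed
end
```

The SoftAbs methods would instead return the J-absorbed matrices built from the eigendecomposition, which is exactly where the shared code path comes from.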

O(n³) Complexity

By precomputing J-dependent matrices, we maintain O(n³) complexity per gradient evaluation rather than O(n⁴). The trace products are computed efficiently using tr(A*B) = sum(A' .* B) in O(n²).
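The trace identity is worth spelling out: tr(A*B) touches only the diagonal of the product, so forming the full matrix product is wasted work.

```julia
using LinearAlgebra

# tr(A*B) = Σᵢⱼ A[i,j] * B[j,i] = sum(transpose(A) .* B):
# an O(n²) elementwise reduction instead of the O(n³) matrix multiply.
tr_prod(A::AbstractMatrix, B::AbstractMatrix) = sum(transpose(A) .* B)
```

Both tr(M_logdet * ∂P/∂θᵢ) and tr(M_kinetic * ∂P/∂θᵢ) above reduce to this form.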

Caching for Fixed-Point Iterations

The generalized leapfrog integrator requires fixed-point iterations where θ is fixed but r varies. RiemannianGradCache stores all θ-dependent computations (eigendecomposition, logdet gradient terms) to avoid redundant work.

Shared Metric Evaluation

neg_energy and ∂H∂r both need the evaluated metric. An optional G_eval parameter allows sharing the computation, and phasepoint uses this to compute both efficiently.
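A minimal sketch of the sharing pattern for the PD case (the signatures are assumed and the sign convention simplified, with constants dropped; this is not the PR's actual `neg_energy`/`∂H∂r` code):

```julia
using LinearAlgebra

# Factor the metric once at the phase point, then reuse the factorization
# for both the (negative) kinetic energy and the momentum derivative.
neg_energy(G_eval::Cholesky, r) = -0.5 * (logdet(G_eval) + dot(r, G_eval \ r))
∂H∂r(G_eval::Cholesky, r) = G_eval \ r

G = [2.0 0.3; 0.3 1.0]
r = [1.0, -0.5]
G_eval = cholesky(Symmetric(G))   # one O(n³) factorization...
E  = neg_energy(G_eval, r)        # ...reused for the energy
dr = ∂H∂r(G_eval, r)              # ...and for ∂H/∂r
```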

Deprecations

The following types are deprecated with warnings:

  • DenseRiemannianMetric → Use RiemannianMetric or SoftAbsRiemannianMetric
  • IdentityMap → Use RiemannianMetric directly
  • SoftAbsMap → Use SoftAbsRiemannianMetric

Existing code using the deprecated API will continue to work but will emit deprecation warnings — I'd be quite keen to remove these in the long-run.

References

  • Girolami & Calderhead (2011). "Riemann manifold Langevin and Hamiltonian Monte Carlo methods"
  • Betancourt (2012). "A general metric for Riemannian manifold Hamiltonian Monte Carlo"

Comment on lines +17 to +19
- `premetric`: a function which, for a given posterior position `pos`, computes either
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- `premetric`: a function which, for a given posterior position `pos`, computes either
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),
- `premetric`: a function which, for a given posterior position `pos`, computes either
a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or
b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`),

@nsiccha
Contributor

nsiccha commented Jan 7, 2026

I couldn't help but feel, however, that they were all making small (though useful) tweaks without fundamentally addressing the poor interface design of the current RHMC implementation.

Makes sense to me. I started looking at the old PR seriously shortly before my vacation last year, but relatively quickly realized that I wouldn't want to merge it as is, and that we'd need to change anywhere from a few things to a lot of them. I haven't finished the last PR because I had been thinking slightly further than what's currently there, about how to share as much code as possible between the different integrators as well, but I haven't implemented the solution yet.

I'll have a look at this PR next week when I'm back!

@J-Price-3

This sounds like a big improvement. The current implementation was quite difficult to improve on because the code is so disjointed. This should make it a lot easier to ensure our performance is good. I'll review the changes and can work on some tests in the next week or so.

@J-Price-3

For the SoftAbsEval type, it might make sense to somehow optionally provide the resulting matrix from softabs. The old code was always doing this by default using Q * diagm(softabsλ) * Q', which prevents having to do this multiplication more than once. We shouldn't do this every time though, because some cases don't actually need the full multiplied result.

@THargreaves
Author

Makes sense to me. I started looking at the old PR seriously shortly before my vacation last year, but relatively quickly realized that I wouldn't want to merge it as is, and that we'd need to change anywhere from a few things to a lot of them. I haven't finished the last PR because I had been thinking slightly further than what's currently there, about how to share as much code as possible between the different integrators as well, but I haven't implemented the solution yet.

I'll have a look at this PR next week when I'm back!

Glad to hear we're on the same page about a rewrite being needed. Looking forward to seeing your thoughts!

@THargreaves
Author

For the SoftAbsEval type, it might make sense to somehow optionally provide the resulting matrix from softabs. The old code was always doing this by default using Q * diagm(softabsλ) * Q', which prevents having to do this multiplication more than once. We shouldn't do this every time though, because some cases don't actually need the full multiplied result.

That's a good thing to consider but I don't think we ever actually use the reconstructed matrix.

I think this was only returned by the old softabs function so that ∂H∂r could do G \ r with dense G (out of laziness, since Julia doesn't provide a linear-algebra-compatible eigendecomposition type). Now we use SoftAbsEval directly to solve G \ r much more efficiently using

function Base.:\(G::SoftAbsEval, p::AbstractVector)
    # Rotate into the eigenbasis, divide by the SoftAbs eigenvalues,
    # and rotate back: O(n²) once the eigendecomposition is cached.
    return G.Q * ((G.Q' * p) ./ G.softabsλ)
end

@nsiccha
Contributor

nsiccha commented Jan 12, 2026

@THargreaves and @J-Price-3: I think the general pattern that appears here (but which I've encountered elsewhere as well) is that a) there's a simple dependence structure between the various intermediate and final QOIs that downstream algorithms need, b) there's no "easy" way to "code up the math as in the paper" and get an efficient implementation, c) there's obvious potential (though in this case probably marginal) benefit in reusing working memory across computations or iterations, and d) the efficient "computation paths" may vary for different downstream algorithms (e.g. generalized leapfrog or implicit midpoint integrators).

I think there should be a potentially easy and general solution to this combination of problems!

@THargreaves
Author

@THargreaves and @J-Price-3: I think the general pattern that appears here (but which I've encountered elsewhere as well) is that a) there's a simple dependence structure between the various intermediate and final QOIs that downstream algorithms need, b) there's no "easy" way to "code up the math as in the paper" and get an efficient implementation, c) there's obvious potential (though in this case probably marginal) benefit in reusing working memory across computations or iterations, and d) the efficient "computation paths" may vary for different downstream algorithms (e.g. generalized leapfrog or implicit midpoint integrators).

I think there should be a potentially easy and general solution to this combination of problems!

I agree with this assessment, especially point (b). Given the downstream use-cases of Turing, I would lean towards departing from the "maths in the paper" to focus on an efficient abstraction.

@nsiccha
Contributor

nsiccha commented Jan 12, 2026

I'm putting this in here even though it doesn't really belong in the PR discussion, but RHMC has been one of the inspirations to think about the below pattern again.

Essentially, I'd like to use Julia macros to be able to specify that dependency graph very concisely, and then have functions/algorithms which compute the needed quantities and only the needed quantities efficiently - essentially reactive programming. I didn't quite find a package which would do this the way I'd like to do it. Maybe there's a good reason for that.

This is for example more or less the functionality I'd like (but probably not the final syntax), first for a Euclidean metric using the standard leapfrog integrators:

@reactive euclidean_phasepoint(pot_f, metric, pos, mom) = begin 
    pot, dpot_dpos = pot_f(pos)

    chol_metric = cholesky(metric)
    # The logdet term could be left out as it doesn't change
    kin = .5 * (@node(logdet(chol_metric)) + dot(mom, dkin_dmom))
    dkin_dmom = chol_metric \ mom

    ham = pot + kin
    dham_dpos = dpot_dpos
    dham_dmom = dkin_dmom
end

@reactive leapfrog!(phasepoint; stepsize) = begin 
    @. phasepoint.mom -= .5 * stepsize * phasepoint.dham_dpos
    @. phasepoint.pos += stepsize * phasepoint.dham_dmom
    @. phasepoint.mom -= .5 * stepsize * phasepoint.dham_dpos
end

The first function definition would specify the parse-/compile-time dependency graph, with @node generating an anonymous node if that makes sense. The second function then uses that dependency graph to appropriately recompute the needed QOIs for every line of code whenever needed.

The same functionality, but for the identity map Riemannian metric together with the generalized leapfrog integrator:

@reactive riemannian_phasepoint(pot_f, metric_f, pos, mom) = begin
    pot, dpot_dpos = pot_f(pos)

    # It would be better to have the interface be to compute (pot, dpot_dpos, metric, metric_grad) in one swoop
    metric, metric_grad = metric_f(pos)
    chol_metric = cholesky(metric)
    inv_metric = Symmetric(inv(chol_metric))
    kin = .5 * (@node(logdet(chol_metric)) + dot(mom, dkin_dmom))
    dkin_dmom = chol_metric \ mom
    dkin_dpos = @node(map(eachindex(pos)) do i
        .5 * tr_prod(inv_metric, metric_grad[:, :, i])
    end) .- Base.broadcasted(eachindex(pos)) do i
        .5 * tr_prod(Base.broadcasted(*, dkin_dmom, dkin_dmom'), metric_grad[:, :, i])
    end

    ham = pot + kin
    dham_dpos = dpot_dpos + dkin_dpos
    dham_dmom = dkin_dmom
end

@reactive generalized_leapfrog!(phasepoint; stepsize, n_fi_steps) = begin 
    pos0, mom0 = map(copy, (phasepoint.pos, phasepoint.mom))
    for _ in 1:n_fi_steps 
        @. phasepoint.mom = mom0 - .5 * stepsize * phasepoint.dham_dpos
    end
    dham_dmom0 = copy(phasepoint.dham_dmom)
    for _ in 1:n_fi_steps 
        @. phasepoint.pos = pos0 + .5 * stepsize * (dham_dmom0 + phasepoint.dham_dmom)
    end
    @. phasepoint.mom -= .5 * stepsize * phasepoint.dham_dpos
end

The implementation for the implicit midpoint integrator would then be

@reactive implicit_midpoint!(phasepoint; stepsize, n_fi_steps) = begin 
    pos0, mom0 = map(copy, (phasepoint.pos, phasepoint.mom))
    for _ in 1:n_fi_steps 
        (;dham_dmom, dham_dpos) = phasepoint
        @. phasepoint.pos = pos0 + .5 * stepsize * dham_dmom
        @. phasepoint.mom = mom0 - .5 * stepsize * dham_dpos
    end
    @. phasepoint.pos = 2 * phasepoint.pos - pos0
    @. phasepoint.mom = 2 * phasepoint.mom - mom0
end

And finally, to make all of the above work with the softabs map:

@reactive riemannian_phasepoint(pot_f, premetric_f, metric_f::typeof(softabs), pos, mom) = begin 
    pot, dpot_dpos = pot_f(pos)

    # It would be better to have the interface be to compute (pot, dpot_dpos, premetric, premetric_grad) in one swoop
    premetric, premetric_grad = premetric_f(pos)
    metric, Q, premetric_eigvals, metric_eigvals = metric_f(premetric)

    # inv(metric) could and should be avoided here - this is just for demo purposes
    inv_metric = inv(metric)
    dkin_dmom = inv_metric * mom
    kin = -.5 * (@node(sum(log, metric_eigvals)) + dot(mom, dkin_dmom))
    J = make_J(premetric_eigvals, alpha)
    RJ = Diagonal(diagonal(J) ./ metric_eigvals)
    D = Diagonal((Q' * mom) ./ metric_eigvals)
    dkin_dpos = @node(map(1:dim) do i
        .5 * tr_prod(@node(Q * RJ * Q'), premetric_grad[:, :, i])
    end) - @node(map(1:dim) do i
        .5 * tr_prod(@node(Q * (D * J * D) * Q'), premetric_grad[:, :, i])
    end)

    ham = pot + kin
    dham_dpos = dpot_dpos + dkin_dpos
    dham_dmom = dkin_dmom
end

Thoughts, @THargreaves or @J-Price-3?


Notes: I think some signs are wrong, and some of the algorithms may also be wrong - I didn't take care to double check everything for this syntax-POC.

@@ -0,0 +1,342 @@
import LinearAlgebra
Contributor

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
import LinearAlgebra
using LinearAlgebra: LinearAlgebra

@github-actions
Contributor

AdvancedHMC.jl documentation for PR #485 is available at:
https://TuringLang.github.io/AdvancedHMC.jl/previews/PR485/

@codecov

codecov bot commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 95.80420% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 90.68%. Comparing base (6bc0c74) to head (6494b34).

Files with missing lines Patch % Lines
src/riemannian/metric.jl 94.50% 5 Missing ⚠️
src/riemannian/hamiltonian.jl 97.77% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #485       +/-   ##
===========================================
+ Coverage   77.71%   90.68%   +12.96%     
===========================================
  Files          21       22        +1     
  Lines        1270     1288       +18     
===========================================
+ Hits          987     1168      +181     
+ Misses        283      120      -163     

☔ View full report in Codecov by Sentry.

@nsiccha
Contributor

nsiccha commented Jan 14, 2026

@THargreaves and @J-Price-3, the linked file below includes everything that's needed for rhmc with the generalized leapfrog integrator or the implicit midpoint integrator, with the linked lines including everything for the softabs metric: https://github.com/nsiccha/ReactiveObjects.jl/blob/899f2e84b9e6f5dc1e00898b07da6b97137dd51c/docs/main.jl#L120-L150

This was referenced Jan 16, 2026
@THargreaves
Author

@J-Price-3 Thanks for adding the tests!

@nsiccha I'm now happy with the high-level design of things (minus some of the efficiency/caching improvements that you suggested).

I think there are some places that I've left a bit messy just to get the code working. E.g. some potential type instability here:

Base.eltype(::RiemannianMetric) = Any # Will use eltype(θ) as fallback

Now that I'm happy this is working, I'm going to use this branch to try to get something sorted for ICML and then I can come back to clean it up once that deadline has passed. Probably best to get everything all neat and finalised before moving into main.

@J-Price-3

After updating the validity tests, it seems that there may be a problem on macOS. One test consistently fails even after altering hyperparameters. I have confirmed (on Windows, at least) that it should never fail, by trying lots of RNG seeds.
