Conversation
| - `premetric`: a function which, for a given posterior position `pos`, computes either |
| a) a symmetric, **positive definite** matrix acting as the position dependent Riemannian metric (if `metric_map = IdentityMap()`), or |
| b) a symmetric, **not necessarily positive definite** matrix acting as the position dependent Riemannian metric after being passed through the `metric_map` argument, which will have to ensure that its return value *is* positive definite (like `metric_map = SoftAbsMap(alpha)`), |
[JuliaFormatter] reported by reviewdog 🐶
Makes sense to me. I started looking at the old PR seriously shortly before my vacation last year, but quickly realized that I wouldn't want to merge it as is, and that we'd need to change a few to a lot of things. I haven't finished the last PR because I had been thinking slightly beyond what's currently there, in particular about how to share as much code as possible between the different integrators, but haven't implemented that solution yet. I'll have a look at this PR next week when I'm back!
|
This sounds like a big improvement. The current implementation was quite difficult to improve on because the code is quite disjointed. This should make it a lot easier to ensure our performance is good. I'll review the changes and can work on some tests in the next week or so.
|
For the
Glad to hear we're on the same page about a rewrite being needed. Looking forward to seeing your thoughts! |
That's a good thing to consider but I don't think we ever actually use the reconstructed matrix. I think this was only returned by the old |
|
@THargreaves and @J-Price-3: I think the general pattern that appears here (but which I've encountered elsewhere as well) is that a) there's a simple dependence structure between the various intermediate and final QOIs that downstream algorithms need, b) there's no "easy" way to "code up the math as in the paper" and get an efficient implementation, c) there's obvious potential (though in this case probably marginal) benefit in reusing working memory across computations or iterations, and d) the efficient "computation paths" may vary for different downstream algorithms (e.g. generalized leapfrog or implicit midpoint integrators). I think there should be a potentially easy and general solution to this combination of problems!
I agree with this assessment, especially point (b). Given the downstream use-cases of Turing, I would lean towards departing from the "maths in the paper" to focus on an efficient abstraction.
|
I'm putting this in here even though it doesn't really belong in the PR discussion, but RHMC has been one of the inspirations to think about the below pattern again. Essentially, I'd like to use Julia macros to be able to specify that dependency graph very concisely, and then have functions/algorithms which compute the needed quantities, and only the needed quantities, efficiently - essentially reactive programming. I didn't quite find a package which would do this the way I'd like to do it. Maybe there's a good reason for that. This is for example more or less the functionality I'd like (but probably not the final syntax), first for a Euclidean metric using the standard leapfrog integrator:

```julia
@reactive euclidean_phasepoint(pot_f, metric, pos, mom) = begin
    pot, dpot_dpos = pot_f(pos)
    chol_metric = cholesky(metric)
    # The logdet term could be left out as it doesn't change
    kin = .5 * (@node(logdet(chol_metric)) + dot(mom, dkin_dmom))
    dkin_dmom = chol_metric \ mom
    ham = pot + kin
    dham_dpos = dpot_dpos
    dham_dmom = dkin_dmom
end

@reactive leapfrog!(phasepoint; stepsize) = begin
    @. phasepoint.mom -= .5 * stepsize * phasepoint.dham_dpos
    @. phasepoint.pos += stepsize * phasepoint.dham_dmom
    @. phasepoint.mom -= .5 * stepsize * phasepoint.dham_dpos
end
```

The first function definition would specify the parse-/compile-time dependency graph.

The same functionality, but for the identity map Riemannian metric together with the generalized leapfrog integrator:

```julia
@reactive riemannian_phasepoint(pot_f, metric_f, pos, mom) = begin
    pot, dpot_dpos = pot_f(pos)
    # It would be better to have the interface be to compute (pot, dpot_dpos, metric, metric_grad) in one swoop
    metric, metric_grad = metric_f(pos)
    chol_metric = cholesky(metric)
    inv_metric = Symmetric(inv(chol_metric))
    kin = .5 * (@node(logdet(chol_metric)) + dot(mom, dkin_dmom))
    dkin_dmom = chol_metric \ mom
    dkin_dpos = @node(map(eachindex(pos)) do i
        .5 * tr_prod(inv_metric, metric_grad[:, :, i])
    end) .- Base.broadcasted(eachindex(pos)) do i
        .5 * tr_prod(Base.broadcasted(*, dkin_dmom, dkin_dmom'), metric_grad[:, :, i])
    end
    ham = pot + kin
    dham_dpos = dpot_dpos + dkin_dpos
    dham_dmom = dkin_dmom
end

@reactive generalized_leapfrog!(phasepoint; stepsize, n_fi_steps) = begin
    pos0, mom0 = map(copy, (phasepoint.pos, phasepoint.mom))
    for _ in 1:n_fi_steps
        @. phasepoint.mom = mom0 - .5 * stepsize * phasepoint.dham_dpos
    end
    dham_dmom0 = copy(phasepoint.dham_dmom)
    for _ in 1:n_fi_steps
        @. phasepoint.pos = pos0 + .5 * stepsize * (dham_dmom0 + phasepoint.dham_dmom)
    end
    @. phasepoint.mom -= .5 * stepsize * phasepoint.dham_dpos
end
```

The implementation for the implicit midpoint integrator would then be:

```julia
@reactive implicit_midpoint!(phasepoint; stepsize, n_fi_steps) = begin
    pos0, mom0 = map(copy, (phasepoint.pos, phasepoint.mom))
    for _ in 1:n_fi_steps
        (;dham_dmom, dham_dpos) = phasepoint
        @. phasepoint.pos = pos0 + .5 * stepsize * dham_dmom
        @. phasepoint.mom = mom0 - .5 * stepsize * dham_dpos
    end
    @. phasepoint.pos = 2 * phasepoint.pos - pos0
    @. phasepoint.mom = 2 * phasepoint.mom - mom0
end
```

And finally, to make all of the above work with the softabs map:

```julia
@reactive riemannian_phasepoint(pot_f, premetric_f, metric_f::typeof(softabs), pos, mom) = begin
    pot, dpot_dpos = pot_f(pos)
    # It would be better to have the interface be to compute (pot, dpot_dpos, premetric, premetric_grad) in one swoop
    premetric, premetric_grad = premetric_f(pos)
    metric, Q, premetric_eigvals, metric_eigvals = metric_f(premetric)
    # inv(metric) could and should be avoided here - this is just for demo purposes
    inv_metric = inv(metric)
    dkin_dmom = inv_metric * mom
    kin = -.5 * (@node(sum(log, metric_eigvals)) + dot(mom, dkin_dmom))
    J = make_J(premetric_eigvals, alpha)
    RJ = Diagonal(diag(J) ./ metric_eigvals)
    D = Diagonal((Q' * mom) ./ metric_eigvals)
    dkin_dpos = @node(map(1:dim) do i
        .5 * tr_prod(@node(Q * RJ * Q'), premetric_grad[:, :, i])
    end) - @node(map(1:dim) do i
        .5 * tr_prod(@node(Q * (D * J * D) * Q'), premetric_grad[:, :, i])
    end)
    ham = pot + kin
    dham_dpos = dpot_dpos + dkin_dpos
    dham_dmom = dkin_dmom
end
```

Thoughts, @THargreaves or @J-Price-3?

Notes: I think some signs are wrong, and some of the algorithms may also be wrong - I didn't take care to double-check everything for this syntax POC.
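To make the "compute only what's needed" idea above concrete without the (not-yet-existing) `@reactive` macro, here is a minimal hand-rolled sketch using memoized thunks: each quantity is a node that, on first demand, pulls in exactly the upstream nodes it depends on. All names here (`Lazy`, `lazy_phasepoint`, etc.) are illustrative, not part of the PR.

```julia
using LinearAlgebra

# A node in the dependency graph: a thunk plus a cache slot.
struct Lazy
    f::Function
    cache::Base.RefValue{Any}
end
Lazy(f) = Lazy(f, Ref{Any}(nothing))

function value(n::Lazy)
    # Computed at most once, on first demand (sketch: assumes no node
    # legitimately evaluates to `nothing`).
    if n.cache[] === nothing
        n.cache[] = n.f()
    end
    return n.cache[]
end

# Euclidean-metric phase point: quantities reference each other via closures,
# so asking for `ham` triggers the Cholesky, kinetic energy, etc. as needed.
function lazy_phasepoint(pot_f, metric, pos, mom)
    chol_metric = Lazy(() -> cholesky(metric))
    pot_grad = Lazy(() -> pot_f(pos))  # returns (pot, dpot_dpos)
    dkin_dmom = Lazy(() -> value(chol_metric) \ mom)
    kin = Lazy(() -> 0.5 * (logdet(value(chol_metric)) + dot(mom, value(dkin_dmom))))
    ham = Lazy(() -> value(pot_grad)[1] + value(kin))
    return (; chol_metric, dkin_dmom, kin, ham)
end

# Standard normal potential: pot = 0.5‖q‖², gradient = q
pot_f(q) = (0.5 * dot(q, q), q)
pp = lazy_phasepoint(pot_f, Matrix(1.0I, 2, 2), [1.0, 0.0], [0.0, 1.0])
value(pp.ham)  # H = pot + kin = 0.5 + 0.5*(logdet(I) + 1) = 1.0
```

The `@reactive` macro would essentially generate this wiring (plus cache invalidation when `pos` or `mom` change) from the plain assignment syntax.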
| @@ -0,0 +1,342 @@ |
| import LinearAlgebra |
[JuliaFormatter] reported by reviewdog 🐶
| import LinearAlgebra |
| using LinearAlgebra: LinearAlgebra |
|
AdvancedHMC.jl documentation for PR #485 is available at:
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff            @@
##             main     #485      +/-  ##
==========================================
+ Coverage   77.71%   90.68%   +12.96%
==========================================
  Files          21       22        +1
  Lines        1270     1288       +18
==========================================
+ Hits          987     1168      +181
+ Misses        283      120      -163
```
|
|
@THargreaves and @J-Price-3, the linked file below includes everything that's needed for RHMC with the generalized leapfrog integrator or the implicit midpoint integrator, with the linked lines including everything for the softabs metric: https://github.com/nsiccha/ReactiveObjects.jl/blob/899f2e84b9e6f5dc1e00898b07da6b97137dd51c/docs/main.jl#L120-L150
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
|
@J-Price-3 Thanks for adding the tests! @nsiccha I'm now happy with the high-level design of things (minus some of the efficiency/caching improvements that you suggested). I think there are some places that I've left a bit messy just to get the code working, e.g. some potential type instability here: `src/riemannian/metric.jl`, line 256, at commit 7e91495. Now that I'm happy this is working, I'm going to use this branch to try to get something sorted for ICML, and then I can come back to clean it up once that deadline has passed. Probably best to get everything neat and finalised before moving into main.
|
After updating the validity tests, it seems there may be a problem with the macOS version: one test consistently fails even after altering hyperparameters. I have confirmed that it should never fail (on Windows at least) by using lots of RNG seeds.
Overview
Insert relevant XKCD — many apologies guys but I felt like another RHMC PR was needed.
I've had a good look through Qingyu's, Niko's, and Jamie's PRs and found them to be a great start, with some useful ideas on tidying up the RHMC interface.
I couldn't help but feel, however, that they were all making small (though useful) tweaks without fundamentally addressing the poor interface design of the current RHMC implementation.
This PR builds on Niko's PR (which builds on Qingyu's PR) whilst pulling in some ideas from Jamie's, and fundamentally reshapes the RHMC interface so it is philosophically sound.
The high-level takeaway of these changes is that RHMC now works with both the SoftAbs eigendecomposition of G and Cholesky decompositions of guaranteed-PD G, whilst sharing 90% of their code paths.
The key insight is that the gradient computation can be unified by introducing two dispatch functions (`logdet_grad_matrix` and `kinetic_grad_matrix`) that return matrices encoding the appropriate chain rule for each metric type.

We also introduce the notion of a pre-metric: for a PD metric this is just G, for SoftAbs this is H. This abstraction allows the SoftAbs Hamiltonian derivatives to be computed with optimal caching (since we can pull out the derivative part of G into `logdet_grad_matrix`, leaving just H behind as in Betancourt's paper).

I've added tests for energy conservation and for verifying that the samplers run, but it would be nice to add some validity tests too. Could this be something @J-Price-3 could look at?
Please feel free to tear this PR apart; it's just a draft and I'm open to changes.
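For readers unfamiliar with the SoftAbs regularization referenced above (from Betancourt's paper), here is a minimal, self-contained sketch of the eigenvalue transform: a symmetric, possibly indefinite Hessian H is made positive definite by mapping each eigenvalue λ to λ·coth(αλ), which approximates |λ| for large α|λ| and is bounded below by 1/α. The helper names (`softabs`, `softabs_metric`) are illustrative, not the PR's API.

```julia
using LinearAlgebra

# SoftAbs eigenvalue map: λ ↦ λ*coth(αλ), with the λ→0 limit 1/α guarded.
softabs(λ, α) = abs(λ) < 1e-8 ? 1 / α : λ * coth(α * λ)

# Build a PD metric G from a symmetric pre-metric H via its eigendecomposition.
function softabs_metric(H::Symmetric, α)
    E = eigen(H)                        # H = Q * Diagonal(λ) * Q'
    Q, λ = E.vectors, E.values
    softλ = softabs.(λ, α)              # all strictly positive
    G = Symmetric(Q * Diagonal(softλ) * Q')
    return G, Q, λ, softλ
end

H = Symmetric([1.0 2.0; 2.0 -3.0])      # indefinite: eigenvalues ≈ 1.83, -3.83
G, Q, λ, softλ = softabs_metric(H, 1e6) # large α: softλ ≈ |λ|
isposdef(G)                             # true
```

The `SoftAbsEval` type described below caches exactly the pieces this sketch recomputes (Q, the transformed eigenvalues, and the J matrix for the gradient chain rule).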
New Types
`RiemannianMetric`

For user-provided positive-definite metrics (e.g., Fisher information matrix):

`SoftAbsRiemannianMetric`

For Hessian-based metrics with SoftAbs regularization:

`SoftAbsEval`

Cached result of evaluating a SoftAbs metric, storing:

- `Q`: eigenvectors
- `softabsλ`: transformed eigenvalues λ * coth(αλ)
- `J`: divided-difference matrix for the gradient chain rule
- `M_logdet`: precomputed matrix for the logdet gradient

Key Design Decisions
Unified Gradient Computation
The Hamiltonian gradient `∂H/∂θ` has three terms:

- `-∂ℓπ/∂θ`
- `0.5 * tr(M_logdet * ∂P/∂θ)`
- `-0.5 * tr(M_kinetic * ∂P/∂θ)`

where P is the "pre-metric" (G for `RiemannianMetric`, H for `SoftAbsRiemannianMetric`).

The dispatch functions `logdet_grad_matrix(G)` and `kinetic_grad_matrix(G, r)` return `M_logdet` and `M_kinetic` respectively, with the J matrix (divided-difference formula) absorbed to handle the chain rule through the SoftAbs transformation.

O(n³) Complexity
By precomputing J-dependent matrices, we maintain O(n³) complexity per gradient evaluation rather than O(n⁴). The trace products are computed efficiently using `tr(A*B) = sum(A' .* B)` in O(n²).

Caching for Fixed-Point Iterations
The generalized leapfrog integrator requires fixed-point iterations where θ is fixed but r varies. `RiemannianGradCache` stores all θ-dependent computations (eigendecomposition, logdet gradient terms) to avoid redundant work.

Shared Metric Evaluation

`neg_energy` and `∂H∂r` both need the evaluated metric. An optional `G_eval` parameter allows sharing the computation, and `phasepoint` uses this to compute both efficiently.

Deprecations
The following types are deprecated with warnings:
- `DenseRiemannianMetric` → use `RiemannianMetric` or `SoftAbsRiemannianMetric`
- `IdentityMap` → use `RiemannianMetric` directly
- `SoftAbsMap` → use `SoftAbsRiemannianMetric`

Existing code using the deprecated API will continue to work but will emit deprecation warnings — I'd be quite keen to remove these in the long run.
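As a small aside, the O(n²) trace-product identity quoted in the design notes (`tr(A*B) = sum(A' .* B)`) is easy to sanity-check in isolation; the `tr_prod` helper below is illustrative, not necessarily the PR's exact implementation.

```julia
using LinearAlgebra

# tr(A*B) = Σ_i (A*B)[i,i] = Σ_i Σ_j A[i,j]*B[j,i] = sum(A' .* B),
# so no n×n matrix product needs to be materialized: O(n²) instead of O(n³).
tr_prod(A, B) = sum(A' .* B)

A = [1.0 2.0; 3.0 4.0]
B = [5.0 6.0; 7.0 8.0]
tr_prod(A, B) == tr(A * B)  # true (both equal 69.0)
```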
References