Skip to content

Commit 3f4620a

Browse files
xukai92yebaiKai Xugithub-actions[bot]
authored
Implement Riemannian HMC (#305)
* feat: implement dense RHMC * chore: update gitignore * chore: add plain version of notebook for discussion * Re-organising files. * move file * move file * Update README.md * feat: introduce a type for different RHMC variants and fix derivative bug * chore: remove unused code * feat: Gweke test * chore: rerun the code with Julia 1.6.7 * chore: add comments of equation refs * chore: add missing eq * fix: flip sign in dHdx * fix: flip sign in dHdx * chore: add missing deps and comments * fix: flip sign in dHdx for Fisher info metric * fix: use the alternative derivation for the quadratic term in SoftAbs metric * feat: cache terms out of loop * chore: run geweke test * research: double check sign * fix: sign of (13) * refactor: clean up test * test: enable more tests * fix: add braces to make v1 work * chore: keep the cache friendly version only * fix: revert debugging change * chore: remove debug code and unsed variables * chore: clean up tests * refactor: update naming * chore: remove unused file * refactor: update interface for code reuse * perf: cache unchanged terms in generalized LF * chore: notebooks for analysis and validation * perf: use ForwardDiff to compute ∂G∂θ * chore: rerun notebooks * Rename riemannian_hmc-sampler.jl to riemannian_hmc_sampler.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Rename riemannian_hmc_sampler.jl to riemannian_hmc_utility.jl * moved jupyter notebooks into new folder * some updates * fix incorrect filename * fix formatting * add missing ReverseDiff dep * add missing PyCall to test dep * add missing dep * update notebooks * fix test dep * Update research/tests/runtests.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update research/tests/runtests.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix res dep --------- Co-authored-by: Kai Xu <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Kai Xu <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent ab0b078 commit 3f4620a

13 files changed

+2554
-11
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
.vscode
22
.history
33
.DS_Store
4+
.ipynb_checkpoints
45
Manifest.toml
6+
!test/experimental/Manifest.toml

research/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
33
AdaptiveRejectionSampling = "c75e803d-635f-53bd-ab7d-544e482d8c75"
44
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
55
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
6+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
7+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
68
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
911
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1012
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
13+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1114
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1215
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
1316
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

research/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ In order to use algorithms in this folder, please navigate to the AdvancedHMC fo
33

44
```
55
] activate research/
6-
] develop src/
6+
] develop .
77
] instantiate
88
```

research/notebooks/geweke_test.ipynb

Lines changed: 1266 additions & 0 deletions
Large diffs are not rendered by default.

research/notebooks/riemannian_hmc.ipynb

Lines changed: 528 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "1d21a625-bc05-4cf9-80a4-04debf3ee038",
7+
"metadata": {},
8+
"outputs": [
9+
{
10+
"name": "stderr",
11+
"output_type": "stream",
12+
"text": [
13+
"\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/.julia/dev/AdvancedHMC/research/tests`\n"
14+
]
15+
},
16+
{
17+
"name": "stdout",
18+
"output_type": "stream",
19+
"text": [
20+
"Julia Version 1.9.0-rc2\n",
21+
"Commit 72aec423c2a (2023-04-01 10:41 UTC)\n",
22+
"Platform Info:\n",
23+
" OS: macOS (x86_64-apple-darwin21.4.0)\n",
24+
" CPU: 20 × Apple M1 Ultra\n",
25+
" WORD_SIZE: 64\n",
26+
" LIBM: libopenlibm\n",
27+
" LLVM: libLLVM-14.0.6 (ORCJIT, westmere)\n",
28+
" Threads: 2 on 20 virtual cores\n",
29+
"\u001b[32m\u001b[1mStatus\u001b[22m\u001b[39m `~/.julia/dev/AdvancedHMC/research/tests/Project.toml`\n",
30+
" \u001b[90m[c75e803d] \u001b[39mAdaptiveRejectionSampling v0.1.1\n",
31+
" \u001b[90m[0bf59076] \u001b[39mAdvancedHMC v0.4.5\n",
32+
" \u001b[90m[6e4b80f9] \u001b[39mBenchmarkTools v1.3.2\n",
33+
" \u001b[90m[863f3e99] \u001b[39mComonicon v1.0.4\n",
34+
" \u001b[90m[163ba53b] \u001b[39mDiffResults v1.1.0\n",
35+
" \u001b[90m[31c24e10] \u001b[39mDistributions v0.25.87\n",
36+
" \u001b[90m[366bfd00] \u001b[39mDynamicPPL v0.22.2\n",
37+
" \u001b[90m[6a86dc24] \u001b[39mFiniteDiff v2.19.0\n",
38+
" \u001b[90m[f6369f11] \u001b[39mForwardDiff v0.10.35\n",
39+
" \u001b[90m[6d524b87] \u001b[39mMCMCDebugging v0.2.1 `https://github.com/TuringLang/MCMCDebugging.jl#master`\n",
40+
" \u001b[90m[91a5bcdd] \u001b[39mPlots v1.38.9\n",
41+
" \u001b[90m[d330b81b] \u001b[39mPyPlot v2.11.1\n",
42+
" \u001b[90m[e0db7c4e] \u001b[39mReTest v0.3.2\n",
43+
" \u001b[90m[37e2e3b7] \u001b[39mReverseDiff v1.14.4\n",
44+
" \u001b[90m[a8a75453] \u001b[39mStatProfilerHTML v1.5.0\n",
45+
" \u001b[90m[8a639fad] \u001b[39mVecTargets v0.2.0 `https://github.com/xukai92/VecTargets.jl#main`\n"
46+
]
47+
}
48+
],
49+
"source": [
50+
"using InteractiveUtils, Pkg\n",
51+
"using AdvancedHMC; Pkg.activate(pkgdir(AdvancedHMC) * \"/research/tests\"); Pkg.instantiate()\n",
52+
"versioninfo(); Pkg.status()"
53+
]
54+
},
55+
{
56+
"cell_type": "code",
57+
"execution_count": 2,
58+
"id": "3d0c7de5-8dc2-477a-bde7-57ebc7bcb671",
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"using LinearAlgebra, FiniteDiff, VecTargets"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 3,
68+
"id": "017c5260-fc75-41b5-b886-f8ab3ad52432",
69+
"metadata": {},
70+
"outputs": [
71+
{
72+
"data": {
73+
"text/plain": [
74+
"#9 (generic function with 1 method)"
75+
]
76+
},
77+
"execution_count": 3,
78+
"metadata": {},
79+
"output_type": "execute_result"
80+
}
81+
],
82+
"source": [
83+
"target = Funnel()\n",
84+
"\n",
85+
"ℓπ = x -> logpdf(target, x)\n",
86+
"neg_ℓπ = x -> -logpdf(target, x)\n",
87+
"\n",
88+
"H = x -> VecTargets.gen_hess(ℓπ, x)(x)[3]\n",
89+
"G = x -> VecTargets.gen_hess(neg_ℓπ, x)(x)[3]"
90+
]
91+
},
92+
{
93+
"cell_type": "code",
94+
"execution_count": 4,
95+
"id": "2f694d5b-54e8-4793-8999-a0ffec9eae0b",
96+
"metadata": {},
97+
"outputs": [
98+
{
99+
"name": "stdout",
100+
"output_type": "stream",
101+
"text": [
102+
"H(xt) = [-0.1111111111111111 0.0; 0.0 -1.0]\n",
103+
"G(xt) = [0.1111111111111111 0.0; 0.0 1.0]\n",
104+
"FiniteDiff.finite_difference_gradient(Hamiltonian_partial, xt) = [-0.9999999999502168, 0.0]\n",
105+
"[tr(inv(G(xt)) * Jt[1:2, 1:2]), tr(inv(G(xt)) * Jt[3:4, 1:2])] = [-0.9999999664723873, 0.0]\n"
106+
]
107+
},
108+
{
109+
"data": {
110+
"text/plain": [
111+
"2-element Vector{Float64}:\n",
112+
" -0.9999999664723873\n",
113+
" 0.0"
114+
]
115+
},
116+
"execution_count": 4,
117+
"metadata": {},
118+
"output_type": "execute_result"
119+
}
120+
],
121+
"source": [
122+
"xt = [0.0, 0.0] # x test\n",
123+
"\n",
124+
"@show H(xt) G(xt)\n",
125+
"\n",
126+
"# Hamiltonian_partial(x) = begin x\n",
127+
"# lad, s = logabsdet(G(x))\n",
128+
"# lad * s\n",
129+
"# end # WRONG implementation of the second term of (13)\n",
130+
" # `s` returned is the sign of `det(G)`, not the whole thing\n",
131+
"Hamiltonian_partial(x) = begin x\n",
132+
" logdet(G(x))\n",
133+
"end # second term of (13)\n",
134+
"\n",
135+
"@show FiniteDiff.finite_difference_gradient(Hamiltonian_partial, xt)\n",
136+
"\n",
137+
"Jt = FiniteDiff.finite_difference_jacobian(G, xt)\n",
138+
"@show [tr(inv(G(xt)) * Jt[1:2,1:2]), tr(inv(G(xt)) * Jt[3:4,1:2])]"
139+
]
140+
},
141+
{
142+
"cell_type": "code",
143+
"execution_count": null,
144+
"id": "c5ec162e-0114-44a7-b23f-3366286c5414",
145+
"metadata": {},
146+
"outputs": [],
147+
"source": []
148+
}
149+
],
150+
"metadata": {
151+
"kernelspec": {
152+
"display_name": "julia 1.9.0-rc2",
153+
"language": "julia",
154+
"name": "julia-1.9"
155+
},
156+
"language_info": {
157+
"file_extension": ".jl",
158+
"mimetype": "application/julia",
159+
"name": "julia",
160+
"version": "1.9.0"
161+
}
162+
},
163+
"nbformat": 4,
164+
"nbformat_minor": 5
165+
}

research/src/relativistic_hmc.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ end
5353

5454

5555
using AdaptiveRejectionSampling: RejectionSampler, run_sampler!
56+
import AdvancedHMC: _rand
5657

5758
# TODO Support AbstractVector{<:AbstractRNG}
5859
function _rand(

0 commit comments

Comments
 (0)