Skip to content

Commit 7ea27f3

Browse files
authored
Allow ADTypes AD backend selection in Hamiltonian (#405)
* Allow ADTypes AD backend selection in Hamiltonian * Apply reviews * Fix CI failings * Integrate DifferentiationInterface * Fix typos and remove LDPAD from deps * Lets only focus on adtypes * Fix some typos * Remove default AD choose * Misc * Also test when provided as LogDensityModel * Dont forget other dependencies for ADTypes ext * Proper using * Proper using
1 parent a96ab41 commit 7ea27f3

File tree

4 files changed

+69
-3
lines changed

4 files changed

+69
-3
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,34 +17,38 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1717
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1818

1919
[weakdeps]
20+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2021
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2122
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2223
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2324

2425
[extensions]
26+
AdvancedHMCADTypesExt = "ADTypes"
2527
AdvancedHMCCUDAExt = "CUDA"
2628
AdvancedHMCMCMCChainsExt = "MCMCChains"
2729
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
2830

2931
[compat]
32+
ADTypes = "1"
3033
AbstractMCMC = "5.6"
3134
ArgCheck = "1, 2"
3235
CUDA = "3, 4, 5"
3336
DocStringExtensions = "0.8, 0.9"
37+
LinearAlgebra = "<0.1, 1"
3438
LogDensityProblems = "2"
3539
LogDensityProblemsAD = "1"
3640
MCMCChains = "5, 6"
3741
OrdinaryDiffEq = "6"
3842
ProgressMeter = "1"
43+
Random = "<0.1, 1"
3944
Setfield = "0.7, 0.8, 1"
4045
Statistics = "1.6"
4146
StatsBase = "0.31, 0.32, 0.33, 0.34"
4247
StatsFuns = "0.8, 0.9, 1"
43-
LinearAlgebra = "<0.1, 1"
44-
Random = "<0.1, 1"
4548
julia = "1.10"
4649

4750
[extras]
51+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4852
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
4953
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
5054
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"

ext/AdvancedHMCADTypesExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
module AdvancedHMCADTypesExt
2+
3+
using AdvancedHMC:
4+
AbstractMetric, LogDensityModel, Hamiltonian, LogDensityProblems, LogDensityProblemsAD
5+
using ADTypes: AbstractADType
6+
7+
function Hamiltonian(
8+
metric::AbstractMetric, ℓπ::LogDensityModel, kind::AbstractADType; kwargs...
9+
)
10+
return Hamiltonian(metric, ℓπ.logdensity, kind; kwargs...)
11+
end
12+
function Hamiltonian(metric::AbstractMetric, ℓπ, kind::AbstractADType; kwargs...)
13+
if LogDensityProblems.capabilities(ℓπ) === nothing
14+
throw(
15+
ArgumentError(
16+
"The log density function does not support the LogDensityProblems.jl interface",
17+
),
18+
)
19+
end
20+
= LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...)
21+
return Hamiltonian(metric, ℓ)
22+
end
23+
24+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
23
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
34
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
45
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"

test/demo.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ReTest
22
using AdvancedHMC, Distributions, ForwardDiff, ComponentArrays, AbstractMCMC
3-
using LinearAlgebra
3+
using LinearAlgebra, ADTypes
44

55
@testset "Demo" begin
66
# Define the target distribution using the `LogDensityProblem` interface
@@ -105,3 +105,40 @@ end
105105
@test "μ" labels
106106
@test "σ" labels
107107
end
108+
109+
@testset "ADTypes" begin
110+
# Set the number of samples to draw and warmup iterations
111+
n_samples, n_adapts = 2_000, 1_000
112+
initial_θ = rand(D)
113+
# Define a Hamiltonian system
114+
metric = DiagEuclideanMetric(D)
115+
116+
hamiltonian_ldp = Hamiltonian(metric, ℓπ_gdemo, AutoForwardDiff())
117+
118+
model = AbstractMCMC.LogDensityModel(ℓπ_gdemo)
119+
hamiltonian_ldm = Hamiltonian(metric, model, AutoForwardDiff())
120+
121+
for hamiltonian in (hamiltonian_ldp, hamiltonian_ldm)
122+
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
123+
integrator = Leapfrog(initial_ϵ)
124+
125+
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn()))
126+
adaptor = StanHMCAdaptor(
127+
MassMatrixAdaptor(metric), StepSizeAdaptor(0.8, integrator)
128+
)
129+
130+
samples, stats = sample(
131+
hamiltonian,
132+
kernel,
133+
initial_θ,
134+
n_samples,
135+
adaptor,
136+
n_adapts;
137+
progress=false,
138+
verbose=false,
139+
)
140+
141+
@test length(samples) == n_samples
142+
@test length(stats) == n_samples
143+
end
144+
end

0 commit comments

Comments
 (0)