Skip to content

Commit fe9bee2

Browse files
devmotiontorfjelde
andauthored
Use iteration interface of DynamicHMC (#1497)
* Use iteration interface of DynamicHMC * Increase number of samples * Remove `@debug` messages (add them upstream!) * Update src/contrib/inference/dynamichmc.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent 181aca8 commit fe9bee2

File tree

5 files changed

+114
-123
lines changed

5 files changed

+114
-123
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.15.5"
3+
version = "0.15.6"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -18,7 +18,6 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
1818
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1919
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
2020
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
21-
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2221
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2322
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
2423
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -47,7 +46,6 @@ DynamicPPL = "0.10.2"
4746
EllipticalSliceSampling = "0.4"
4847
ForwardDiff = "0.10.3"
4948
Libtask = "0.4, 0.5"
50-
LogDensityProblems = "^0.9, 0.10"
5149
MCMCChains = "4"
5250
NamedArrays = "0.9"
5351
Reexport = "0.2.0"

docs/src/using-turing/dynamichmc.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,14 @@ title: Using DynamicHMC
66

77
Turing supports the use of [DynamicHMC](https://github.com/tpapp/DynamicHMC.jl) as a sampler through the `DynamicNUTS` function.
88

9-
10-
`DynamicNUTS` is not appropriate for use in compositional inference. If you intend to use [Gibbs]({{site.baseurl}}/docs/library/#Turing.Inference.Gibbs) sampling, you must use Turing's native `NUTS` function.
11-
12-
139
To use the `DynamicNUTS` function, you must import the `DynamicHMC` package as well as Turing. Turing does not formally require `DynamicHMC` but will include additional functionality if both packages are present.
1410

1511
Here is a brief example of how to apply `DynamicNUTS`:
1612

1713

1814
```julia
1915
# Import Turing and DynamicHMC.
20-
using LogDensityProblems, DynamicHMC, Turing
16+
using DynamicHMC, Turing
2117

2218
# Model definition.
2319
@model function gdemo(x, y)

src/Turing.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,15 @@ using .Variational
5555
# end
5656
# end
5757

58-
@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin
59-
import ..DynamicHMC
60-
61-
if isdefined(DynamicHMC, :mcmc_with_warmup)
62-
using ..DynamicHMC: mcmc_with_warmup
63-
include("contrib/inference/dynamichmc.jl")
64-
else
65-
error("Please update DynamicHMC, v1.x is no longer supported")
58+
@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
59+
@eval Inference begin
60+
import ..DynamicHMC
61+
62+
if isdefined(DynamicHMC, :mcmc_with_warmup)
63+
include("contrib/inference/dynamichmc.jl")
64+
else
65+
error("Please update DynamicHMC, v1.x is no longer supported")
66+
end
6667
end
6768
end
6869

src/contrib/inference/dynamichmc.jl

Lines changed: 95 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,64 @@
11
###
22
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
33
###
4-
struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end
54

6-
using LogDensityProblems: LogDensityProblems
5+
"""
6+
DynamicNUTS
77
8-
struct FunctionLogDensity{F}
9-
dimension::Int
10-
f::F
11-
end
8+
Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package.
9+
10+
To use it, make sure you have DynamicHMC package (version >= 2) loaded:
11+
```julia
12+
using DynamicHMC
13+
```
14+
"""
15+
struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end
1216

13-
LogDensityProblems.dimension(ℓ::FunctionLogDensity) =.dimension
17+
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
18+
DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()
1419

15-
function LogDensityProblems.capabilities(::Type{<:FunctionLogDensity})
16-
LogDensityProblems.LogDensityOrder{1}()
20+
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
21+
22+
struct DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo}
23+
model::M
24+
sampler::S
25+
varinfo::V
1726
end
1827

19-
function LogDensityProblems.logdensity(ℓ::FunctionLogDensity, x::AbstractVector)
20-
first(ℓ.f(x))
28+
function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity)
29+
return length(ℓ.varinfo[ℓ.sampler])
2130
end
2231

23-
function LogDensityProblems.logdensity_and_gradient(ℓ::FunctionLogDensity,
24-
x::AbstractVector)
25-
.f(x)
32+
function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity})
33+
return DynamicHMC.LogDensityOrder{1}()
34+
end
35+
36+
function DynamicHMC.logdensity_and_gradient(
37+
::DynamicHMCLogDensity,
38+
x::AbstractVector,
39+
)
40+
return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler)
2641
end
2742

2843
"""
29-
DynamicNUTS()
44+
DynamicNUTSState
3045
31-
Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make
32-
sure you have the DynamicHMC package (version `2.*`) loaded:
46+
State of the [`DynamicNUTS`](@ref) sampler.
3347
34-
```julia
35-
using DynamicHMC
36-
``
48+
# Fields
49+
$(TYPEDFIELDS)
3750
"""
38-
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
39-
DynamicNUTS{AD}() where AD = DynamicNUTS{AD, ()}()
40-
function DynamicNUTS{AD}(space::Symbol...) where AD
41-
DynamicNUTS{AD, space}()
42-
end
43-
44-
struct DynamicNUTSState{V<:AbstractVarInfo,D}
51+
struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S}
4552
vi::V
46-
draws::Vector{D}
53+
"Cache of sample, log density, and gradient of log density."
54+
cache::C
55+
metric::M
56+
stepsize::S
4757
end
4858

49-
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
59+
function gibbs_update_state(state::DynamicNUTSState, varinfo::AbstractVarInfo)
60+
return DynamicNUTSState(varinfo, state.cache, state.metric, state.stepsize)
61+
end
5062

5163
DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()
5264

@@ -55,44 +67,39 @@ function DynamicPPL.initialstep(
5567
model::Model,
5668
spl::Sampler{<:DynamicNUTS},
5769
vi::AbstractVarInfo;
58-
N::Int,
5970
kwargs...
6071
)
61-
# Set up lp function.
62-
function _lp(x)
63-
gradient_logp(x, vi, model, spl)
64-
end
65-
66-
link!(vi, spl)
67-
l, dl = _lp(vi[spl])
68-
while !isfinite(l) || !isfinite(dl)
69-
model(vi, SampleFromUniform())
70-
link!(vi, spl)
71-
l, dl = _lp(vi[spl])
72-
end
73-
74-
if spl.selector.tag == :default && !islinked(vi, spl)
75-
link!(vi, spl)
76-
model(vi, spl)
72+
# Ensure that initial sample is in unconstrained space.
73+
if !DynamicPPL.islinked(vi, spl)
74+
DynamicPPL.link!(vi, spl)
75+
model(rng, vi, spl)
7776
end
7877

79-
results = mcmc_with_warmup(
78+
# Perform initial step.
79+
results = DynamicHMC.mcmc_keep_warmup(
8080
rng,
81-
FunctionLogDensity(
82-
length(vi[spl]),
83-
_lp
84-
),
85-
N
81+
DynamicHMCLogDensity(model, spl, vi),
82+
0;
83+
initialization = (q = vi[spl],),
84+
reporter = DynamicHMC.NoProgressReport(),
8685
)
87-
draws = results.chain
86+
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
87+
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
8888

89-
# Compute first transition and state.
90-
draw = popfirst!(draws)
91-
vi[spl] = draw
92-
transition = Transition(vi)
93-
state = DynamicNUTSState(vi, draws)
89+
# Update the variables.
90+
vi[spl] = Q.q
91+
DynamicPPL.setlogp!(vi, Q.ℓq)
9492

95-
return transition, state
93+
# If a Gibbs component, transform the values back to the constrained space.
94+
if spl.selector.tag !== :default
95+
DynamicPPL.invlink!(vi, spl)
96+
end
97+
98+
# Create first sample and state.
99+
sample = Transition(vi)
100+
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)
101+
102+
return sample, state
96103
end
97104

98105
function AbstractMCMC.step(
@@ -102,55 +109,38 @@ function AbstractMCMC.step(
102109
state::DynamicNUTSState;
103110
kwargs...
104111
)
105-
# Extract VarInfo object.
112+
# Compute next sample.
106113
vi = state.vi
107-
108-
# Pop the next draw off the vector.
109-
draw = popfirst!(state.draws)
110-
vi[spl] = draw
111-
112-
# Compute next transition.
113-
transition = Transition(vi)
114-
115-
return transition, state
116-
end
117-
118-
# Disable the progress logging for DynamicHMC, since it has its own progress meter.
119-
function AbstractMCMC.sample(
120-
rng::AbstractRNG,
121-
model::AbstractModel,
122-
alg::DynamicNUTS,
123-
N::Integer;
124-
chain_type=MCMCChains.Chains,
125-
resume_from=nothing,
126-
progress=PROGRESS[],
127-
kwargs...
128-
)
129-
if progress
130-
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
131-
end
132-
if resume_from === nothing
133-
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
134-
chain_type=chain_type, progress=false, N=N, kwargs...)
114+
= DynamicHMCLogDensity(model, spl, vi)
115+
steps = DynamicHMC.mcmc_steps(
116+
rng,
117+
DynamicHMC.NUTS(),
118+
state.metric,
119+
ℓ,
120+
state.stepsize,
121+
)
122+
Q = if spl.selector.tag !== :default
123+
# When a Gibbs component, transform values to the unconstrained space
124+
# and update the previous evaluation.
125+
DynamicPPL.link!(vi, spl)
126+
DynamicHMC.evaluate_ℓ(ℓ, vi[spl])
135127
else
136-
return resume(resume_from, N; chain_type=chain_type, progress=false, N=N, kwargs...)
128+
state.cache
137129
end
138-
end
130+
newQ, _ = DynamicHMC.mcmc_next_step(steps, Q)
139131

140-
function AbstractMCMC.sample(
141-
rng::AbstractRNG,
142-
model::AbstractModel,
143-
alg::DynamicNUTS,
144-
parallel::AbstractMCMC.AbstractMCMCParallel,
145-
N::Integer,
146-
n_chains::Integer;
147-
chain_type=MCMCChains.Chains,
148-
progress=PROGRESS[],
149-
kwargs...
150-
)
151-
if progress
152-
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
132+
# Update the variables.
133+
vi[spl] = newQ.q
134+
DynamicPPL.setlogp!(vi, newQ.ℓq)
135+
136+
# If a Gibbs component, transform the values back to the constrained space.
137+
if spl.selector.tag !== :default
138+
DynamicPPL.invlink!(vi, spl)
153139
end
154-
return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
155-
chain_type=chain_type, progress=false, N=N, kwargs...)
140+
141+
# Create next sample and state.
142+
sample = Transition(vi)
143+
newstate = DynamicNUTSState(vi, newQ, state.metric, state.stepsize)
144+
145+
return sample, newstate
156146
end

test/contrib/inference/dynamichmc.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ include(dir*"/test/test_utils/AllUtils.jl")
1010

1111
@test DynamicPPL.alg_str(Sampler(DynamicNUTS(), gdemo_default)) == "DynamicNUTS"
1212

13-
chn = sample(gdemo_default, DynamicNUTS(), 5000)
14-
check_numerical(chn, [:s, :m], [49/24, 7/6], atol=0.2)
13+
chn = sample(gdemo_default, DynamicNUTS(), 10_000)
14+
check_gdemo(chn)
15+
16+
chn2 = sample(gdemo_default, Gibbs(PG(15, :s), DynamicNUTS(:m)), 10_000)
17+
check_gdemo(chn2)
18+
19+
chn3 = sample(gdemo_default, Gibbs(DynamicNUTS(:s), ESS(:m)), 10_000)
20+
check_gdemo(chn3)
1521
end

0 commit comments

Comments
 (0)