1
1
# ##
2
2
# ## DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
3
3
# ##
4
- struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end
5
4
6
- using LogDensityProblems: LogDensityProblems
5
+ """
6
+ DynamicNUTS
7
7
8
- struct FunctionLogDensity{F}
9
- dimension:: Int
10
- f:: F
11
- end
8
+ Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package.
9
+
10
+ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
11
+ ```julia
12
+ using DynamicHMC
13
+ ```
14
+ """
15
+ struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end
12
16
13
- LogDensityProblems. dimension (ℓ:: FunctionLogDensity ) = ℓ. dimension
17
+ DynamicNUTS (args... ) = DynamicNUTS {ADBackend()} (args... )
18
+ DynamicNUTS {AD} (space:: Symbol... ) where AD = DynamicNUTS {AD, space} ()
14
19
15
- function LogDensityProblems. capabilities (:: Type{<:FunctionLogDensity} )
16
- LogDensityProblems. LogDensityOrder {1} ()
20
+ DynamicPPL. getspace (:: DynamicNUTS{<:Any, space} ) where {space} = space
21
+
22
+ struct DynamicHMCLogDensity{M<: Model ,S<: Sampler{<:DynamicNUTS} ,V<: AbstractVarInfo }
23
+ model:: M
24
+ sampler:: S
25
+ varinfo:: V
17
26
end
18
27
19
- function LogDensityProblems . logdensity (ℓ:: FunctionLogDensity , x :: AbstractVector )
20
- first (ℓ. f (x) )
28
+ function DynamicHMC . dimension (ℓ:: DynamicHMCLogDensity )
29
+ return length (ℓ. varinfo[ℓ . sampler] )
21
30
end
22
31
23
- function LogDensityProblems. logdensity_and_gradient (ℓ:: FunctionLogDensity ,
24
- x:: AbstractVector )
25
- ℓ. f (x)
32
+ function DynamicHMC. capabilities (:: Type{<:DynamicHMCLogDensity} )
33
+ return DynamicHMC. LogDensityOrder {1} ()
34
+ end
35
+
36
+ function DynamicHMC. logdensity_and_gradient (
37
+ ℓ:: DynamicHMCLogDensity ,
38
+ x:: AbstractVector ,
39
+ )
40
+ return gradient_logp (x, ℓ. varinfo, ℓ. model, ℓ. sampler)
26
41
end
27
42
28
43
"""
29
- DynamicNUTS()
44
+ DynamicNUTSState
30
45
31
- Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make
32
- sure you have the DynamicHMC package (version `2.*`) loaded:
46
+ State of the [`DynamicNUTS`](@ref) sampler.
33
47
34
- ```julia
35
- using DynamicHMC
36
- ``
48
+ # Fields
49
+ $(TYPEDFIELDS)
37
50
"""
38
- DynamicNUTS (args... ) = DynamicNUTS {ADBackend()} (args... )
39
- DynamicNUTS {AD} () where AD = DynamicNUTS {AD, ()} ()
40
- function DynamicNUTS {AD} (space:: Symbol... ) where AD
41
- DynamicNUTS {AD, space} ()
42
- end
43
-
44
- struct DynamicNUTSState{V<: AbstractVarInfo ,D}
51
+ struct DynamicNUTSState{V<: AbstractVarInfo ,C,M,S}
45
52
vi:: V
46
- draws:: Vector{D}
53
+ " Cache of sample, log density, and gradient of log density."
54
+ cache:: C
55
+ metric:: M
56
+ stepsize:: S
47
57
end
48
58
49
- DynamicPPL. getspace (:: DynamicNUTS{<:Any, space} ) where {space} = space
59
+ function gibbs_update_state (state:: DynamicNUTSState , varinfo:: AbstractVarInfo )
60
+ return DynamicNUTSState (varinfo, state. cache, state. metric, state. stepsize)
61
+ end
50
62
51
63
DynamicPPL. initialsampler (:: Sampler{<:DynamicNUTS} ) = SampleFromUniform ()
52
64
@@ -55,44 +67,39 @@ function DynamicPPL.initialstep(
55
67
model:: Model ,
56
68
spl:: Sampler{<:DynamicNUTS} ,
57
69
vi:: AbstractVarInfo ;
58
- N:: Int ,
59
70
kwargs...
60
71
)
61
- # Set up lp function.
62
- function _lp (x)
63
- gradient_logp (x, vi, model, spl)
64
- end
65
-
66
- link! (vi, spl)
67
- l, dl = _lp (vi[spl])
68
- while ! isfinite (l) || ! isfinite (dl)
69
- model (vi, SampleFromUniform ())
70
- link! (vi, spl)
71
- l, dl = _lp (vi[spl])
72
- end
73
-
74
- if spl. selector. tag == :default && ! islinked (vi, spl)
75
- link! (vi, spl)
76
- model (vi, spl)
72
+ # Ensure that initial sample is in unconstrained space.
73
+ if ! DynamicPPL. islinked (vi, spl)
74
+ DynamicPPL. link! (vi, spl)
75
+ model (rng, vi, spl)
77
76
end
78
77
79
- results = mcmc_with_warmup (
78
+ # Perform initial step.
79
+ results = DynamicHMC. mcmc_keep_warmup (
80
80
rng,
81
- FunctionLogDensity (
82
- length (vi[spl]),
83
- _lp
84
- ),
85
- N
81
+ DynamicHMCLogDensity (model, spl, vi),
82
+ 0 ;
83
+ initialization = (q = vi[spl],),
84
+ reporter = DynamicHMC. NoProgressReport (),
86
85
)
87
- draws = results. chain
86
+ steps = DynamicHMC. mcmc_steps (results. sampling_logdensity, results. final_warmup_state)
87
+ Q, _ = DynamicHMC. mcmc_next_step (steps, results. final_warmup_state. Q)
88
88
89
- # Compute first transition and state.
90
- draw = popfirst! (draws)
91
- vi[spl] = draw
92
- transition = Transition (vi)
93
- state = DynamicNUTSState (vi, draws)
89
+ # Update the variables.
90
+ vi[spl] = Q. q
91
+ DynamicPPL. setlogp! (vi, Q. ℓq)
94
92
95
- return transition, state
93
+ # If a Gibbs component, transform the values back to the constrained space.
94
+ if spl. selector. tag != = :default
95
+ DynamicPPL. invlink! (vi, spl)
96
+ end
97
+
98
+ # Create first sample and state.
99
+ sample = Transition (vi)
100
+ state = DynamicNUTSState (vi, Q, steps. H. κ, steps. ϵ)
101
+
102
+ return sample, state
96
103
end
97
104
98
105
function AbstractMCMC. step (
@@ -102,55 +109,38 @@ function AbstractMCMC.step(
102
109
state:: DynamicNUTSState ;
103
110
kwargs...
104
111
)
105
- # Extract VarInfo object .
112
+ # Compute next sample .
106
113
vi = state. vi
107
-
108
- # Pop the next draw off the vector.
109
- draw = popfirst! (state. draws)
110
- vi[spl] = draw
111
-
112
- # Compute next transition.
113
- transition = Transition (vi)
114
-
115
- return transition, state
116
- end
117
-
118
- # Disable the progress logging for DynamicHMC, since it has its own progress meter.
119
- function AbstractMCMC. sample (
120
- rng:: AbstractRNG ,
121
- model:: AbstractModel ,
122
- alg:: DynamicNUTS ,
123
- N:: Integer ;
124
- chain_type= MCMCChains. Chains,
125
- resume_from= nothing ,
126
- progress= PROGRESS[],
127
- kwargs...
128
- )
129
- if progress
130
- @warn " [HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
131
- end
132
- if resume_from === nothing
133
- return AbstractMCMC. sample (rng, model, Sampler (alg, model), N;
134
- chain_type= chain_type, progress= false , N= N, kwargs... )
114
+ ℓ = DynamicHMCLogDensity (model, spl, vi)
115
+ steps = DynamicHMC. mcmc_steps (
116
+ rng,
117
+ DynamicHMC. NUTS (),
118
+ state. metric,
119
+ ℓ,
120
+ state. stepsize,
121
+ )
122
+ Q = if spl. selector. tag != = :default
123
+ # When a Gibbs component, transform values to the unconstrained space
124
+ # and update the previous evaluation.
125
+ DynamicPPL. link! (vi, spl)
126
+ DynamicHMC. evaluate_ℓ (ℓ, vi[spl])
135
127
else
136
- return resume (resume_from, N; chain_type = chain_type, progress = false , N = N, kwargs ... )
128
+ state . cache
137
129
end
138
- end
130
+ newQ, _ = DynamicHMC . mcmc_next_step (steps, Q)
139
131
140
- function AbstractMCMC. sample (
141
- rng:: AbstractRNG ,
142
- model:: AbstractModel ,
143
- alg:: DynamicNUTS ,
144
- parallel:: AbstractMCMC.AbstractMCMCParallel ,
145
- N:: Integer ,
146
- n_chains:: Integer ;
147
- chain_type= MCMCChains. Chains,
148
- progress= PROGRESS[],
149
- kwargs...
150
- )
151
- if progress
152
- @warn " [HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
132
+ # Update the variables.
133
+ vi[spl] = newQ. q
134
+ DynamicPPL. setlogp! (vi, newQ. ℓq)
135
+
136
+ # If a Gibbs component, transform the values back to the constrained space.
137
+ if spl. selector. tag != = :default
138
+ DynamicPPL. invlink! (vi, spl)
153
139
end
154
- return AbstractMCMC. sample (rng, model, Sampler (alg, model), parallel, N, n_chains;
155
- chain_type= chain_type, progress= false , N= N, kwargs... )
140
+
141
+ # Create next sample and state.
142
+ sample = Transition (vi)
143
+ newstate = DynamicNUTSState (vi, newQ, state. metric, state. stepsize)
144
+
145
+ return sample, newstate
156
146
end
0 commit comments