@@ -41,82 +41,82 @@ function DynamicNUTS{AD}(space::Symbol...) where AD
41
41
DynamicNUTS {AD, space} ()
42
42
end
43
43
44
- mutable struct DynamicNUTSState{V<: VarInfo , D} <: AbstractSamplerState
44
+ struct DynamicNUTSState{V<: AbstractVarInfo ,D}
45
45
vi:: V
46
46
draws:: Vector{D}
47
47
end
48
48
49
49
DynamicPPL. getspace (:: DynamicNUTS{<:Any, space} ) where {space} = space
50
50
51
- function AbstractMCMC. sample_init! (
51
+ DynamicPPL. initialsampler (:: Sampler{<:DynamicNUTS} ) = SampleFromUniform ()
52
+
53
+ function DynamicPPL. initialstep (
52
54
rng:: AbstractRNG ,
53
55
model:: Model ,
54
56
spl:: Sampler{<:DynamicNUTS} ,
55
- N:: Integer ;
57
+ vi:: AbstractVarInfo ;
58
+ N:: Int ,
56
59
kwargs...
57
60
)
58
61
# Set up lp function.
59
62
function _lp (x)
60
- gradient_logp (x, spl . state . vi, model, spl)
63
+ gradient_logp (x, vi, model, spl)
61
64
end
62
65
63
- # Set the parameters to a starting value.
64
- initialize_parameters! (spl; kwargs... )
65
-
66
- model (spl. state. vi, SampleFromUniform ())
67
- link! (spl. state. vi, spl)
68
- l, dl = _lp (spl. state. vi[spl])
66
+ link! (vi, spl)
67
+ l, dl = _lp (vi[spl])
69
68
while ! isfinite (l) || ! isfinite (dl)
70
- model (spl . state . vi, SampleFromUniform ())
71
- link! (spl . state . vi, spl)
72
- l, dl = _lp (spl . state . vi[spl])
69
+ model (vi, SampleFromUniform ())
70
+ link! (vi, spl)
71
+ l, dl = _lp (vi[spl])
73
72
end
74
73
75
- if spl. selector. tag == :default && ! islinked (spl . state . vi, spl)
76
- link! (spl . state . vi, spl)
77
- model (spl . state . vi, spl)
74
+ if spl. selector. tag == :default && ! islinked (vi, spl)
75
+ link! (vi, spl)
76
+ model (vi, spl)
78
77
end
79
78
80
79
results = mcmc_with_warmup (
81
80
rng,
82
81
FunctionLogDensity (
83
- length (spl . state . vi[spl]),
82
+ length (vi[spl]),
84
83
_lp
85
84
),
86
85
N
87
86
)
87
+ draws = results. chain
88
88
89
- spl. state. draws = results. chain
89
+ # Compute first transition and state.
90
+ draw = popfirst! (draws)
91
+ vi[spl] = draw
92
+ transition = Transition (vi)
93
+ state = DynamicNUTSState (vi, draws)
94
+
95
+ return transition, state
90
96
end
91
97
92
- function AbstractMCMC. step! (
98
+ function AbstractMCMC. step (
93
99
rng:: AbstractRNG ,
94
100
model:: Model ,
95
101
spl:: Sampler{<:DynamicNUTS} ,
96
- N:: Integer ,
97
- transition;
102
+ state:: DynamicNUTSState ;
98
103
kwargs...
99
104
)
105
+ # Extract VarInfo object.
106
+ vi = state. vi
107
+
100
108
# Pop the next draw off the vector.
101
- draw = popfirst! (spl. state. draws)
102
- spl. state. vi[spl] = draw
103
- return Transition (spl)
104
- end
109
+ draw = popfirst! (state. draws)
110
+ vi[spl] = draw
105
111
106
- function Sampler (
107
- alg:: DynamicNUTS ,
108
- model:: Model ,
109
- s:: Selector = Selector ()
110
- )
111
- # Construct a state, using a default function.
112
- state = DynamicNUTSState (VarInfo (model), [])
112
+ # Compute next transition.
113
+ transition = Transition (vi)
113
114
114
- # Return a new sampler.
115
- return Sampler (alg, Dict {Symbol,Any} (), s, state)
115
+ return transition, state
116
116
end
117
117
118
- # Disable the progress logging for DynamicHMC, since it has its own progress meter.
119
- function AbstractMCMC. sample (
118
+ # Disable the progress logging for DynamicHMC, since it has its own progress meter.
119
+ function AbstractMCMC. sample (
120
120
rng:: AbstractRNG ,
121
121
model:: AbstractModel ,
122
122
alg:: DynamicNUTS ,
131
131
end
132
132
if resume_from === nothing
133
133
return AbstractMCMC. sample (rng, model, Sampler (alg, model), N;
134
- chain_type= chain_type, progress= false , kwargs... )
134
+ chain_type= chain_type, progress= false , N = N, kwargs... )
135
135
else
136
- return resume (resume_from, N; chain_type= chain_type, progress= false , kwargs... )
136
+ return resume (resume_from, N; chain_type= chain_type, progress= false , N = N, kwargs... )
137
137
end
138
138
end
139
139
@@ -152,5 +152,5 @@ function AbstractMCMC.sample(
152
152
@warn " [HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
153
153
end
154
154
return AbstractMCMC. sample (rng, model, Sampler (alg, model), parallel, N, n_chains;
155
- chain_type= chain_type, progress= false , kwargs... )
155
+ chain_type= chain_type, progress= false , N = N, kwargs... )
156
156
end
0 commit comments