@@ -34,8 +34,7 @@ function Optim.optimize(
34
34
options:: Optim.Options = Optim. Options ();
35
35
kwargs... ,
36
36
)
37
- vi = DynamicPPL. setaccs!! (VarInfo (model), (DynamicPPL. LogLikelihoodAccumulator (),))
38
- f = Optimisation. OptimLogDensity (model, vi)
37
+ f = Optimisation. OptimLogDensity (model, DynamicPPL. getloglikelihood)
39
38
init_vals = DynamicPPL. getparams (f. ldf)
40
39
optimizer = Optim. LBFGS ()
41
40
return _mle_optimize (model, init_vals, optimizer, options; kwargs... )
@@ -57,8 +56,7 @@ function Optim.optimize(
57
56
options:: Optim.Options = Optim. Options ();
58
57
kwargs... ,
59
58
)
60
- vi = DynamicPPL. setaccs!! (VarInfo (model), (DynamicPPL. LogLikelihoodAccumulator (),))
61
- f = Optimisation. OptimLogDensity (model, vi)
59
+ f = Optimisation. OptimLogDensity (model, DynamicPPL. getloglikelihood)
62
60
init_vals = DynamicPPL. getparams (f. ldf)
63
61
return _mle_optimize (model, init_vals, optimizer, options; kwargs... )
64
62
end
@@ -74,8 +72,7 @@ function Optim.optimize(
74
72
end
75
73
76
74
function _mle_optimize (model:: DynamicPPL.Model , args... ; kwargs... )
77
- vi = DynamicPPL. setaccs!! (VarInfo (model), (DynamicPPL. LogLikelihoodAccumulator (),))
78
- f = Optimisation. OptimLogDensity (model, vi)
75
+ f = Optimisation. OptimLogDensity (model, DynamicPPL. getloglikelihood)
79
76
return _optimize (f, args... ; kwargs... )
80
77
end
81
78
@@ -105,8 +102,7 @@ function Optim.optimize(
105
102
options:: Optim.Options = Optim. Options ();
106
103
kwargs... ,
107
104
)
108
- vi = DynamicPPL. setaccs!! (VarInfo (model), (LogPriorWithoutJacobianAccumulator (), DynamicPPL. LogLikelihoodAccumulator (),))
109
- f = Optimisation. OptimLogDensity (model, vi)
105
+ f = Optimisation. OptimLogDensity (model, Optimisation. getlogjoint_without_jacobian)
110
106
init_vals = DynamicPPL. getparams (f. ldf)
111
107
optimizer = Optim. LBFGS ()
112
108
return _map_optimize (model, init_vals, optimizer, options; kwargs... )
@@ -128,8 +124,7 @@ function Optim.optimize(
128
124
options:: Optim.Options = Optim. Options ();
129
125
kwargs... ,
130
126
)
131
- vi = DynamicPPL. setaccs!! (VarInfo (model), (LogPriorWithoutJacobianAccumulator (), DynamicPPL. LogLikelihoodAccumulator (),))
132
- f = Optimisation. OptimLogDensity (model, vi)
127
+ f = Optimisation. OptimLogDensity (model, Optimisation. getlogjoint_without_jacobian)
133
128
init_vals = DynamicPPL. getparams (f. ldf)
134
129
return _map_optimize (model, init_vals, optimizer, options; kwargs... )
135
130
end
@@ -145,8 +140,7 @@ function Optim.optimize(
145
140
end
146
141
147
142
function _map_optimize (model:: DynamicPPL.Model , args... ; kwargs... )
148
- vi = DynamicPPL. setaccs!! (VarInfo (model), (LogPriorWithoutJacobianAccumulator (), DynamicPPL. LogLikelihoodAccumulator (),))
149
- f = Optimisation. OptimLogDensity (model, vi)
143
+ f = Optimisation. OptimLogDensity (model, Optimisation. getlogjoint_without_jacobian)
150
144
return _optimize (f, args... ; kwargs... )
151
145
end
152
146
@@ -169,7 +163,9 @@ function _optimize(
169
163
# whether initialisation is really necessary at all
170
164
vi = DynamicPPL. unflatten (f. ldf. varinfo, init_vals)
171
165
vi = DynamicPPL. link (vi, f. ldf. model)
172
- f = Optimisation. OptimLogDensity (f. ldf. model, vi; adtype= f. ldf. adtype)
166
+ f = Optimisation. OptimLogDensity (
167
+ f. ldf. model, f. ldf. getlogdensity, vi; adtype= f. ldf. adtype
168
+ )
173
169
init_vals = DynamicPPL. getparams (f. ldf)
174
170
175
171
# Optimize!
@@ -186,7 +182,9 @@ function _optimize(
186
182
# Get the optimum in unconstrained space. `getparams` does the invlinking.
187
183
vi = f. ldf. varinfo
188
184
vi_optimum = DynamicPPL. unflatten (vi, M. minimizer)
189
- logdensity_optimum = Optimisation. OptimLogDensity (f. ldf. model, vi_optimum; adtype= f. ldf. adtype)
185
+ logdensity_optimum = Optimisation. OptimLogDensity (
186
+ f. ldf. model, f. ldf. getlogdensity, vi_optimum; adtype= f. ldf. adtype
187
+ )
190
188
vns_vals_iter = Turing. Inference. getparams (f. ldf. model, vi_optimum)
191
189
varnames = map (Symbol ∘ first, vns_vals_iter)
192
190
vals = map (last, vns_vals_iter)
0 commit comments