1
1
module TuringOptimExt
2
2
3
- if isdefined (Base, :get_extension )
4
- using Turing: Turing
5
- import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
6
- using Optim: Optim
7
- else
8
- import .. Turing
9
- import .. Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
10
- import .. Optim
11
- end
3
+ using Turing: Turing
4
+ import Turing: DynamicPPL, NamedArrays, Accessors, Optimisation
5
+ using Optim: Optim
12
6
13
7
# ###################
14
8
# Optim.jl methods #
@@ -42,7 +36,7 @@ function Optim.optimize(
42
36
)
43
37
ctx = Optimisation. OptimizationContext (DynamicPPL. LikelihoodContext ())
44
38
f = Optimisation. OptimLogDensity (model, ctx)
45
- init_vals = DynamicPPL. getparams (f)
39
+ init_vals = DynamicPPL. getparams (f. ldf )
46
40
optimizer = Optim. LBFGS ()
47
41
return _mle_optimize (model, init_vals, optimizer, options; kwargs... )
48
42
end
@@ -65,7 +59,7 @@ function Optim.optimize(
65
59
)
66
60
ctx = Optimisation. OptimizationContext (DynamicPPL. LikelihoodContext ())
67
61
f = Optimisation. OptimLogDensity (model, ctx)
68
- init_vals = DynamicPPL. getparams (f)
62
+ init_vals = DynamicPPL. getparams (f. ldf )
69
63
return _mle_optimize (model, init_vals, optimizer, options; kwargs... )
70
64
end
71
65
function Optim. optimize (
81
75
82
76
function _mle_optimize (model:: DynamicPPL.Model , args... ; kwargs... )
83
77
ctx = Optimisation. OptimizationContext (DynamicPPL. LikelihoodContext ())
84
- return _optimize (model, Optimisation. OptimLogDensity (model, ctx), args... ; kwargs... )
78
+ return _optimize (Optimisation. OptimLogDensity (model, ctx), args... ; kwargs... )
85
79
end
86
80
87
81
"""
@@ -112,7 +106,7 @@ function Optim.optimize(
112
106
)
113
107
ctx = Optimisation. OptimizationContext (DynamicPPL. DefaultContext ())
114
108
f = Optimisation. OptimLogDensity (model, ctx)
115
- init_vals = DynamicPPL. getparams (f)
109
+ init_vals = DynamicPPL. getparams (f. ldf )
116
110
optimizer = Optim. LBFGS ()
117
111
return _map_optimize (model, init_vals, optimizer, options; kwargs... )
118
112
end
@@ -135,7 +129,7 @@ function Optim.optimize(
135
129
)
136
130
ctx = Optimisation. OptimizationContext (DynamicPPL. DefaultContext ())
137
131
f = Optimisation. OptimLogDensity (model, ctx)
138
- init_vals = DynamicPPL. getparams (f)
132
+ init_vals = DynamicPPL. getparams (f. ldf )
139
133
return _map_optimize (model, init_vals, optimizer, options; kwargs... )
140
134
end
141
135
function Optim. optimize (
@@ -151,28 +145,29 @@ end
151
145
152
146
function _map_optimize (model:: DynamicPPL.Model , args... ; kwargs... )
153
147
ctx = Optimisation. OptimizationContext (DynamicPPL. DefaultContext ())
154
- return _optimize (model, Optimisation. OptimLogDensity (model, ctx), args... ; kwargs... )
148
+ return _optimize (Optimisation. OptimLogDensity (model, ctx), args... ; kwargs... )
155
149
end
156
-
157
150
"""
158
- _optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
151
+ _optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
159
152
160
153
Estimate a mode, i.e., compute a MLE or MAP estimate.
161
154
"""
162
155
function _optimize (
163
- model:: DynamicPPL.Model ,
164
156
f:: Optimisation.OptimLogDensity ,
165
- init_vals:: AbstractArray = DynamicPPL. getparams (f),
157
+ init_vals:: AbstractArray = DynamicPPL. getparams (f. ldf ),
166
158
optimizer:: Optim.AbstractOptimizer = Optim. LBFGS (),
167
159
options:: Optim.Options = Optim. Options (),
168
160
args... ;
169
161
kwargs... ,
170
162
)
171
163
# Convert the initial values, since it is assumed that users provide them
172
164
# in the constrained space.
173
- f = Accessors. @set f. varinfo = DynamicPPL. unflatten (f. varinfo, init_vals)
174
- f = Accessors. @set f. varinfo = DynamicPPL. link (f. varinfo, model)
175
- init_vals = DynamicPPL. getparams (f)
165
+ # TODO (penelopeysm): As with in src/optimisation/Optimisation.jl, unclear
166
+ # whether initialisation is really necessary at all
167
+ vi = DynamicPPL. unflatten (f. ldf. varinfo, init_vals)
168
+ vi = DynamicPPL. link (vi, f. ldf. model)
169
+ f = Optimisation. OptimLogDensity (f. ldf. model, vi, f. ldf. context; adtype= f. ldf. adtype)
170
+ init_vals = DynamicPPL. getparams (f. ldf)
176
171
177
172
# Optimize!
178
173
M = Optim. optimize (Optim. only_fg! (f), init_vals, optimizer, options, args... ; kwargs... )
@@ -186,12 +181,16 @@ function _optimize(
186
181
end
187
182
188
183
# Get the optimum in unconstrained space. `getparams` does the invlinking.
189
- f = Accessors. @set f. varinfo = DynamicPPL. unflatten (f. varinfo, M. minimizer)
190
- vns_vals_iter = Turing. Inference. getparams (model, f. varinfo)
184
+ vi = f. ldf. varinfo
185
+ vi_optimum = DynamicPPL. unflatten (vi, M. minimizer)
186
+ logdensity_optimum = Optimisation. OptimLogDensity (
187
+ f. ldf. model, vi_optimum, f. ldf. context
188
+ )
189
+ vns_vals_iter = Turing. Inference. getparams (f. ldf. model, vi_optimum)
191
190
varnames = map (Symbol ∘ first, vns_vals_iter)
192
191
vals = map (last, vns_vals_iter)
193
192
vmat = NamedArrays. NamedArray (vals, varnames)
194
- return Optimisation. ModeResult (vmat, M, - M. minimum, f )
193
+ return Optimisation. ModeResult (vmat, M, - M. minimum, logdensity_optimum )
195
194
end
196
195
197
196
end # module
0 commit comments