- '''
+ """
Created on Mar 12, 2011

@author: johnsalvatier
- '''
+ """
from scipy.optimize import minimize
import numpy as np
from numpy import isfinite, nan_to_num
...
import warnings
from inspect import getargspec

- __all__ = ['find_MAP']
-
-
- def find_MAP(start=None, vars=None, method="L-BFGS-B",
-              return_raw=False, include_transformed=True, progressbar=True, maxeval=5000, model=None,
-              *args, **kwargs):
+ __all__ = ["find_MAP"]
+
+
+ def find_MAP(
+     start=None,
+     vars=None,
+     method="L-BFGS-B",
+     return_raw=False,
+     include_transformed=True,
+     progressbar=True,
+     maxeval=5000,
+     model=None,
+     *args,
+     **kwargs
+ ):
    """
    Finds the local maximum a posteriori point given a model.

+     find_MAP should not be used to initialize the NUTS sampler. Simply call pymc3.sample() and it will automatically initialize NUTS in a better way.
+
    Parameters
    ----------
    start : `dict` of parameter values (Defaults to `model.test_point`)
@@ -53,24 +64,24 @@ def find_MAP(start=None, vars=None, method="L-BFGS-B",
    Notes
    -----
    Older code examples used find_MAP() to initialize the NUTS sampler,
-     this turned out to be a rather inefficient method.
-     Since then, we have greatly enhanced the initialization of NUTS and
+     but this is not an effective way of choosing starting values for sampling.
+     As a result, we have greatly enhanced the initialization of NUTS and
    wrapped it inside pymc3.sample() and you should thus avoid this method.
    """

-     warnings.warn('find_MAP should not be used to initialize the NUTS sampler, simply call pymc3.sample() and it will automatically initialize NUTS in a better way.')
-
    model = modelcontext(model)
    if start is None:
        start = model.test_point
    else:
        update_start_vals(start, model.test_point, model)

    if not set(start.keys()).issubset(model.named_vars.keys()):
-         extra_keys = ', '.join(set(start.keys()) - set(model.named_vars.keys()))
-         valid_keys = ', '.join(model.named_vars.keys())
-         raise KeyError('Some start parameters do not appear in the model!\n'
-                        'Valid keys are: {}, but {} was supplied'.format(valid_keys, extra_keys))
+         extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys()))
+         valid_keys = ", ".join(model.named_vars.keys())
+         raise KeyError(
+             "Some start parameters do not appear in the model!\n"
+             "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
+         )

    if vars is None:
        vars = model.cont_vars
@@ -90,29 +101,37 @@ def find_MAP(start=None, vars=None, method="L-BFGS-B",
        compute_gradient = False

    if disc_vars or not compute_gradient:
-         pm._log.warning("Warning: gradient not available." +
-                         "(E.g. vars contains discrete variables). MAP " +
-                         "estimates may not be accurate for the default " +
-                         "parameters. Defaulting to non-gradient minimization " +
-                         "'Powell'.")
+         pm._log.warning(
+             "Warning: gradient not available."
+             + "(E.g. vars contains discrete variables). MAP "
+             + "estimates may not be accurate for the default "
+             + "parameters. Defaulting to non-gradient minimization "
+             + "'Powell'."
+         )
        method = "Powell"

    if "fmin" in kwargs:
        fmin = kwargs.pop("fmin")
-         warnings.warn('In future versions, set the optimization algorithm with a string. '
-                       'For example, use `method="L-BFGS-B"` instead of '
-                       '`fmin=sp.optimize.fmin_l_bfgs_b"`.')
+         warnings.warn(
+             "In future versions, set the optimization algorithm with a string. "
+             'For example, use `method="L-BFGS-B"` instead of '
+             '`fmin=sp.optimize.fmin_l_bfgs_b"`.'
+         )

        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)

        # Check to see if minimization function actually uses the gradient
-         if 'fprime' in getargspec(fmin).args:
+         if "fprime" in getargspec(fmin).args:
+
            def grad_logp(point):
                return nan_to_num(-dlogp_func(point))
-             opt_result = fmin(cost_func, bij.map(start), fprime=grad_logp, *args, **kwargs)
+
+             opt_result = fmin(
+                 cost_func, bij.map(start), fprime=grad_logp, *args, **kwargs
+             )
        else:
            # Check to see if minimization function uses a starting value
-             if 'x0' in getargspec(fmin).args:
+             if "x0" in getargspec(fmin).args:
                opt_result = fmin(cost_func, bij.map(start), *args, **kwargs)
            else:
                opt_result = fmin(cost_func, *args, **kwargs)
@@ -129,7 +148,9 @@ def grad_logp(point):
            cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)

        try:
-             opt_result = minimize(cost_func, x0, method=method, jac=compute_gradient, *args, **kwargs)
+             opt_result = minimize(
+                 cost_func, x0, method=method, jac=compute_gradient, *args, **kwargs
+             )
            mx0 = opt_result["x"]  # r -> opt_result
            cost_func.progress.total = cost_func.progress.n + 1
            cost_func.progress.update()
@@ -142,7 +163,9 @@ def grad_logp(point):
            cost_func.progress.close()

    vars = get_default_varnames(model.unobserved_RVs, include_transformed)
-     mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))}
+     mx = {
+         var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))
+     }

    if return_raw:
        return mx, opt_result
@@ -171,11 +194,11 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=No
        self.logp_func = logp_func
        if dlogp_func is None:
            self.use_gradient = False
-             self.desc = 'logp = {:,.5g}'
+             self.desc = "logp = {:,.5g}"
        else:
            self.dlogp_func = dlogp_func
            self.use_gradient = True
-             self.desc = 'logp = {:,.5g}, ||grad|| = {:,.5g}'
+             self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}"
        self.previous_x = None
        self.progress = tqdm(total=maxeval, disable=not progressbar)
        self.progress.n = 0
@@ -187,7 +210,7 @@ def __call__(self, x):
            neg_grad = self.dlogp_func(pm.floatX(x))
            if np.all(np.isfinite(neg_grad)):
                self.previous_x = x
-                 grad = nan_to_num(-1.0 * neg_grad)
+                 grad = nan_to_num(-1.0 * neg_grad)
                grad = grad.astype(np.float64)
            else:
                self.previous_x = x
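
The docstring and Notes changes above steer users away from using find_MAP() to initialize NUTS and toward calling pymc3.sample() directly. The following is a minimal usage sketch of that guidance, not part of the diff; it assumes PyMC3 3.x (3.7 or later for the sigma keyword), and the model and data are made up for illustration.

import numpy as np
import pymc3 as pm

# Synthetic data, for illustration only
data = np.random.normal(loc=1.0, scale=2.0, size=100)

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=10.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=data)

    # find_MAP() is still useful as a quick point estimate ...
    map_estimate = pm.find_MAP(method="L-BFGS-B")

    # ... but for sampling, just call pm.sample(); it initializes NUTS
    # itself rather than relying on a MAP starting point.
    trace = pm.sample(1000, tune=1000)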
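
The warnings.warn() branch in the diff deprecates passing a raw scipy optimizer through fmin= in favour of naming the algorithm with a string, which find_MAP forwards to scipy.optimize.minimize(method=...). A small migration sketch under the same assumptions as above (illustrative only):

import pymc3 as pm

with pm.Model():
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    s = pm.HalfNormal("s", sigma=1.0)

    # Deprecated style, which triggers the warning shown in the diff:
    #   pm.find_MAP(fmin=scipy.optimize.fmin_powell)

    # Preferred style: pass the optimizer name; return_raw=True also
    # returns scipy's raw OptimizeResult alongside the point estimate.
    map_estimate, opt_result = pm.find_MAP(method="Powell", return_raw=True)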