Commit 2fd2a5a

fonnesbeck authored and ColCarroll committed
Removed find_MAP warning (#3672)
* Removed find_MAP warning and added information prominently in the docstring
* Applied black formatting to starting.py
1 parent e29aa0f commit 2fd2a5a
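
With this commit, the "do not use find_MAP to initialize NUTS" guidance lives in the docstring instead of a runtime warning: call pymc3.sample() and let it initialize NUTS automatically. A minimal sketch of that recommended workflow follows; the toy model is hypothetical and not part of this commit.

import pymc3 as pm

# Hypothetical toy model, used only to illustrate the recommended workflow.
with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # Recommended: pm.sample() initializes NUTS automatically,
    # so there is no need to pass start=pm.find_MAP().
    trace = pm.sample(1000)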

File tree

1 file changed: +55 additions, -32 deletions


pymc3/tuning/starting.py

Lines changed: 55 additions & 32 deletions
@@ -1,8 +1,8 @@
-'''
+"""
 Created on Mar 12, 2011
 
 @author: johnsalvatier
-'''
+"""
 from scipy.optimize import minimize
 import numpy as np
 from numpy import isfinite, nan_to_num
@@ -18,15 +18,26 @@
 import warnings
 from inspect import getargspec
 
-__all__ = ['find_MAP']
-
-
-def find_MAP(start=None, vars=None, method="L-BFGS-B",
-             return_raw=False, include_transformed=True, progressbar=True, maxeval=5000, model=None,
-             *args, **kwargs):
+__all__ = ["find_MAP"]
+
+
+def find_MAP(
+    start=None,
+    vars=None,
+    method="L-BFGS-B",
+    return_raw=False,
+    include_transformed=True,
+    progressbar=True,
+    maxeval=5000,
+    model=None,
+    *args,
+    **kwargs
+):
     """
     Finds the local maximum a posteriori point given a model.
 
+    find_MAP should not be used to initialize the NUTS sampler. Simply call pymc3.sample() and it will automatically initialize NUTS in a better way.
+
     Parameters
     ----------
     start : `dict` of parameter values (Defaults to `model.test_point`)
@@ -53,24 +64,24 @@ def find_MAP(start=None, vars=None, method="L-BFGS-B",
     Notes
     -----
     Older code examples used find_MAP() to initialize the NUTS sampler,
-    this turned out to be a rather inefficient method.
-    Since then, we have greatly enhanced the initialization of NUTS and
+    but this is not an effective way of choosing starting values for sampling.
+    As a result, we have greatly enhanced the initialization of NUTS and
     wrapped it inside pymc3.sample() and you should thus avoid this method.
     """
 
-    warnings.warn('find_MAP should not be used to initialize the NUTS sampler, simply call pymc3.sample() and it will automatically initialize NUTS in a better way.')
-
     model = modelcontext(model)
     if start is None:
         start = model.test_point
     else:
         update_start_vals(start, model.test_point, model)
 
     if not set(start.keys()).issubset(model.named_vars.keys()):
-        extra_keys = ', '.join(set(start.keys()) - set(model.named_vars.keys()))
-        valid_keys = ', '.join(model.named_vars.keys())
-        raise KeyError('Some start parameters do not appear in the model!\n'
-                       'Valid keys are: {}, but {} was supplied'.format(valid_keys, extra_keys))
+        extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys()))
+        valid_keys = ", ".join(model.named_vars.keys())
+        raise KeyError(
+            "Some start parameters do not appear in the model!\n"
+            "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
+        )
 
     if vars is None:
         vars = model.cont_vars
@@ -90,29 +101,37 @@ def find_MAP(start=None, vars=None, method="L-BFGS-B",
         compute_gradient = False
 
     if disc_vars or not compute_gradient:
-        pm._log.warning("Warning: gradient not available." +
-                        "(E.g. vars contains discrete variables). MAP " +
-                        "estimates may not be accurate for the default " +
-                        "parameters. Defaulting to non-gradient minimization " +
-                        "'Powell'.")
+        pm._log.warning(
+            "Warning: gradient not available."
+            + "(E.g. vars contains discrete variables). MAP "
+            + "estimates may not be accurate for the default "
+            + "parameters. Defaulting to non-gradient minimization "
+            + "'Powell'."
+        )
         method = "Powell"
 
     if "fmin" in kwargs:
        fmin = kwargs.pop("fmin")
-        warnings.warn('In future versions, set the optimization algorithm with a string. '
-                      'For example, use `method="L-BFGS-B"` instead of '
-                      '`fmin=sp.optimize.fmin_l_bfgs_b"`.')
+        warnings.warn(
+            "In future versions, set the optimization algorithm with a string. "
+            'For example, use `method="L-BFGS-B"` instead of '
+            '`fmin=sp.optimize.fmin_l_bfgs_b"`.'
+        )
 
         cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)
 
         # Check to see if minimization function actually uses the gradient
-        if 'fprime' in getargspec(fmin).args:
+        if "fprime" in getargspec(fmin).args:
+
             def grad_logp(point):
                 return nan_to_num(-dlogp_func(point))
-            opt_result = fmin(cost_func, bij.map(start), fprime=grad_logp, *args, **kwargs)
+
+            opt_result = fmin(
+                cost_func, bij.map(start), fprime=grad_logp, *args, **kwargs
+            )
         else:
             # Check to see if minimization function uses a starting value
-            if 'x0' in getargspec(fmin).args:
+            if "x0" in getargspec(fmin).args:
                 opt_result = fmin(cost_func, bij.map(start), *args, **kwargs)
             else:
                 opt_result = fmin(cost_func, *args, **kwargs)
@@ -129,7 +148,9 @@ def grad_logp(point):
            cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)
 
        try:
-            opt_result = minimize(cost_func, x0, method=method, jac=compute_gradient, *args, **kwargs)
+            opt_result = minimize(
+                cost_func, x0, method=method, jac=compute_gradient, *args, **kwargs
+            )
            mx0 = opt_result["x"]  # r -> opt_result
            cost_func.progress.total = cost_func.progress.n + 1
            cost_func.progress.update()
@@ -142,7 +163,9 @@ def grad_logp(point):
            cost_func.progress.close()
 
    vars = get_default_varnames(model.unobserved_RVs, include_transformed)
-    mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))}
+    mx = {
+        var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))
+    }
 
    if return_raw:
        return mx, opt_result
@@ -171,11 +194,11 @@ def __init__(self, maxeval=5000, progressbar=True, logp_func=None, dlogp_func=None):
        self.logp_func = logp_func
        if dlogp_func is None:
            self.use_gradient = False
-            self.desc = 'logp = {:,.5g}'
+            self.desc = "logp = {:,.5g}"
        else:
            self.dlogp_func = dlogp_func
            self.use_gradient = True
-            self.desc = 'logp = {:,.5g}, ||grad|| = {:,.5g}'
+            self.desc = "logp = {:,.5g}, ||grad|| = {:,.5g}"
        self.previous_x = None
        self.progress = tqdm(total=maxeval, disable=not progressbar)
        self.progress.n = 0
@@ -187,7 +210,7 @@ def __call__(self, x):
            neg_grad = self.dlogp_func(pm.floatX(x))
            if np.all(np.isfinite(neg_grad)):
                self.previous_x = x
-                grad = nan_to_num(-1.0*neg_grad)
+                grad = nan_to_num(-1.0 * neg_grad)
                grad = grad.astype(np.float64)
            else:
                self.previous_x = x
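
For cases where a MAP point estimate is actually the goal, find_MAP itself remains the right tool. The sketch below exercises the reformatted signature from the diff above; the model is again a hypothetical placeholder, not part of this commit.

import pymc3 as pm

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # Returns a dict mapping variable names to their MAP values.
    map_estimate = pm.find_MAP()

    # Select the optimizer by name; passing an fmin callable is deprecated,
    # as the warning retained in this file explains.
    map_powell = pm.find_MAP(method="Powell")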
