
Commit e03c227

find_MAP informative error (#4423), Closes #3143
* Raise ValueError when there are no continuous variables
* Remove deprecated block of code
* pre-commit
* Remove unused imports
1 parent 37ca5ea commit e03c227
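
What the change means in practice (the #3143 scenario): calling find_MAP on a model whose only free variables are discrete now fails fast with the new ValueError instead of breaking later inside the optimizer. A minimal sketch of the new behaviour; the Bernoulli model below is illustrative, not part of the commit:

import pymc3 as pm

with pm.Model() as model:
    # only a discrete variable, so model.cont_vars is empty
    x = pm.Bernoulli("x", p=0.5)
    try:
        pm.find_MAP()
    except ValueError as err:
        print(err)  # -> Model has no unobserved continuous variables.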

File tree

1 file changed: +25 -59 lines changed

pymc3/tuning/starting.py

Lines changed: 25 additions & 59 deletions
@@ -17,10 +17,6 @@
 
 @author: johnsalvatier
 """
-import warnings
-
-from inspect import getargspec
-
 import numpy as np
 import theano.gradient as tg
 

@@ -92,19 +88,21 @@ def find_MAP(
     wrapped it inside pymc3.sample() and you should thus avoid this method.
     """
     model = modelcontext(model)
-    if start is None:
-        start = model.test_point
-    else:
-        update_start_vals(start, model.test_point, model)
-
-    check_start_vals(start, model)
 
     if vars is None:
         vars = model.cont_vars
+        if not vars:
+            raise ValueError("Model has no unobserved continuous variables.")
     vars = inputvars(vars)
     disc_vars = list(typefilter(vars, discrete_types))
     allinmodel(vars, model)
 
+    if start is None:
+        start = model.test_point
+    else:
+        update_start_vals(start, model.test_point, model)
+    check_start_vals(start, model)
+
     start = Point(start, model=model)
     bij = DictToArrayBijection(ArrayOrdering(vars), start)
     logp_func = bij.mapf(model.fastlogp_nojac)
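
For context on the reordering above: start handling (update_start_vals / check_start_vals) still runs, just after the vars check, so the ValueError fires before any start-value validation. A hedged usage sketch with a hypothetical continuous model, showing that a user-supplied start dict is still merged with model.test_point and checked:

import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])
    # the supplied start value is merged with model.test_point and validated
    map_estimate = pm.find_MAP(start={"mu": 0.5})
print(map_estimate["mu"])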
@@ -126,57 +124,25 @@ def find_MAP(
         )
         method = "Powell"
 
-    if "fmin" in kwargs:
-        fmin = kwargs.pop("fmin")
-        warnings.warn(
-            "In future versions, set the optimization algorithm with a string. "
-            'For example, use `method="L-BFGS-B"` instead of '
-            '`fmin=sp.optimize.fmin_l_bfgs_b"`.'
-        )
-
+    if compute_gradient:
+        cost_func = CostFuncWrapper(maxeval, progressbar, logp_func, dlogp_func)
+    else:
         cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)
 
-        # Check to see if minimization function actually uses the gradient
-        if "fprime" in getargspec(fmin).args:
-
-            def grad_logp(point):
-                return nan_to_num(-dlogp_func(point))
-
-            opt_result = fmin(cost_func, x0, fprime=grad_logp, *args, **kwargs)
-        else:
-            # Check to see if minimization function uses a starting value
-            if "x0" in getargspec(fmin).args:
-                opt_result = fmin(cost_func, x0, *args, **kwargs)
-            else:
-                opt_result = fmin(cost_func, *args, **kwargs)
-
-        if isinstance(opt_result, tuple):
-            mx0 = opt_result[0]
-        else:
-            mx0 = opt_result
-    else:
-        # remove 'if' part, keep just this 'else' block after version change
-        if compute_gradient:
-            cost_func = CostFuncWrapper(maxeval, progressbar, logp_func, dlogp_func)
-        else:
-            cost_func = CostFuncWrapper(maxeval, progressbar, logp_func)
-
-        try:
-            opt_result = minimize(
-                cost_func, x0, method=method, jac=compute_gradient, *args, **kwargs
-            )
-            mx0 = opt_result["x"]  # r -> opt_result
-        except (KeyboardInterrupt, StopIteration) as e:
-            mx0, opt_result = cost_func.previous_x, None
-            if isinstance(e, StopIteration):
-                pm._log.info(e)
-        finally:
-            last_v = cost_func.n_eval
-            if progressbar:
-                assert isinstance(cost_func.progress, ProgressBar)
-                cost_func.progress.total = last_v
-                cost_func.progress.update(last_v)
-            print()
+    try:
+        opt_result = minimize(cost_func, x0, method=method, jac=compute_gradient, *args, **kwargs)
+        mx0 = opt_result["x"]  # r -> opt_result
+    except (KeyboardInterrupt, StopIteration) as e:
+        mx0, opt_result = cost_func.previous_x, None
+        if isinstance(e, StopIteration):
+            pm._log.info(e)
+    finally:
+        last_v = cost_func.n_eval
+        if progressbar:
+            assert isinstance(cost_func.progress, ProgressBar)
+            cost_func.progress.total = last_v
+            cost_func.progress.update(last_v)
+        print()
 
     vars = get_default_varnames(model.unobserved_RVs, include_transformed)
     mx = {var.name: value for var, value in zip(vars, model.fastfn(vars)(bij.rmap(mx0)))}
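
With the deprecated fmin= code path removed, the optimizer is chosen only via a scipy.optimize.minimize method name, as the deleted warning already advised. A small sketch under that assumption (the model is illustrative, not from the commit):

import pymc3 as pm

with pm.Model():
    mu = pm.Normal("mu", mu=0.0, sigma=1.0)
    sigma = pm.HalfNormal("sigma", sigma=1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=[0.1, -0.3, 0.2])
    # pass a method string such as "L-BFGS-B" or "Powell";
    # the old fmin=... callable is no longer accepted
    map_estimate = pm.find_MAP(method="L-BFGS-B")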
