
Commit 0bdff96

Do not compute gradient in find_MAP() if it's not required. (#1551)
* ENH Do not compute gradient in find_MAP() if it's not required. Closes #639.
* MAINT Remove unused Poisson import.
* BUG Add missing allfinite() call. Refactor model to models.py
* TST Missed import.
1 parent 818e4cd commit 0bdff96

File tree (3 files changed: +53 -17 lines)

    pymc3/tests/models.py
    pymc3/tests/test_starting.py
    pymc3/tuning/starting.py
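
The user-facing effect of the change, as a hedged usage sketch based on the fixture and test added below (assumes pymc3 of this era with Theano installed; names other than those in the diff are illustrative): a model whose likelihood passes through a non-differentiable @as_op deterministic can now be handed to find_MAP(), which detects the missing gradient and falls back to the gradient-free fmin_powell path instead of failing while building dlogp.

import numpy as np
import pymc3 as pm
import theano.tensor as tt
from theano.compile.ops import as_op

# Wrapping plain Python in as_op gives Theano no symbolic gradient for this node.
@as_op(itypes=[tt.dscalar], otypes=[tt.dscalar])
def arbitrary_det(value):
    return value

with pm.Model():
    a = pm.Normal('a')
    b = arbitrary_det(a)
    pm.Normal('obs', mu=b.astype('float64'), observed=np.array([1., 3., 5.]))
    map_estimate = pm.find_MAP()  # now takes the gradient-free (fmin_powell) path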

pymc3/tests/models.py

Lines changed: 15 additions & 0 deletions
@@ -3,6 +3,7 @@
 import pymc3 as pm
 from itertools import product
 import theano.tensor as tt
+from theano.compile.ops import as_op


 def simple_model():
@@ -34,6 +35,20 @@ def multidimensional_model():
     return model.test_point, model, (mu, tau ** -1)


+def simple_arbitrary_det():
+    @as_op(itypes=[tt.dscalar], otypes=[tt.dscalar])
+    def arbitrary_det(value):
+        return value
+
+    with Model() as model:
+        a = Normal('a')
+        b = arbitrary_det(a)
+        c = Normal('obs', mu=b.astype('float64'),
+                   observed=np.array([1, 3, 5]))
+
+    return model.test_point, model
+
+
 def simple_init():
     start, model, moments = simple_model()
     step = Metropolis(model.vars, np.diag([1.]), model=model)
pymc3/tests/test_starting.py

Lines changed: 7 additions & 2 deletions
@@ -2,8 +2,7 @@
 import numpy as np
 from pymc3.tuning import starting
 from pymc3 import Model, Uniform, Normal, Beta, Binomial, find_MAP, Point
-from .models import simple_model, non_normal, exponential_beta
-
+from .models import simple_model, non_normal, exponential_beta, simple_arbitrary_det

 def test_accuracy_normal():
     _, model, (mu, _) = simple_model()
@@ -53,6 +52,12 @@ def test_find_MAP_discrete():
     assert map_est2['ss'] == 14


+def test_find_MAP_no_gradient():
+    _, model = simple_arbitrary_det()
+    with model:
+        find_MAP()
+
+
 def test_find_MAP():
     tol = 2.0**-11 # 16 bit machine epsilon, a low bar
     data = np.random.randn(100)
pymc3/tuning/starting.py

Lines changed: 31 additions & 15 deletions
@@ -17,8 +17,8 @@
 __all__ = ['find_MAP']


-def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
-             model=None, *args, **kwargs):
+def find_MAP(start=None, vars=None, fmin=None,
+             return_raw=False,model=None, *args, **kwargs):
     """
     Sets state to the local maximum a posteriori point given a model.
     Current default of fmin_Hessian does not deal well with optimizing close
@@ -55,8 +55,15 @@ def find_MAP(start=None, vars=None, fmin=None, return_raw=False,

     disc_vars = list(typefilter(vars, discrete_types))

-    if disc_vars:
-        pm._log.warning("Warning: vars contains discrete variables. MAP " +
+    try:
+        model.fastdlogp(vars)
+        gradient_avail = True
+    except AttributeError:
+        gradient_avail = False
+
+    if disc_vars or not gradient_avail :
+        pm._log.warning("Warning: gradient not available." +
+                        "(E.g. vars contains discrete variables). MAP " +
                         "estimates may not be accurate for the default " +
                         "parameters. Defaulting to non-gradient minimization " +
                         "fmin_powell.")
@@ -74,19 +81,21 @@ def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
     bij = DictToArrayBijection(ArrayOrdering(vars), start)

     logp = bij.mapf(model.fastlogp)
-    dlogp = bij.mapf(model.fastdlogp(vars))
-
     def logp_o(point):
         return nan_to_high(-logp(point))

-    def grad_logp_o(point):
-        return nan_to_num(-dlogp(point))
-
     # Check to see if minimization function actually uses the gradient
     if 'fprime' in getargspec(fmin).args:
+        dlogp = bij.mapf(model.fastdlogp(vars))
+        def grad_logp_o(point):
+            return nan_to_num(-dlogp(point))
+
         r = fmin(logp_o, bij.map(
             start), fprime=grad_logp_o, *args, **kwargs)
+        compute_gradient = True
     else:
+        compute_gradient = False
+
         # Check to see if minimization function uses a starting value
         if 'x0' in getargspec(fmin).args:
             r = fmin(logp_o, bij.map(start), *args, **kwargs)
@@ -100,17 +109,24 @@ def grad_logp_o(point):

     mx = bij.rmap(mx0)

-    if (not allfinite(mx0) or
-            not allfinite(model.logp(mx)) or
-            not allfinite(model.dlogp()(mx))):
+    allfinite_mx0 = allfinite(mx0)
+    allfinite_logp = allfinite(model.logp(mx))
+    if compute_gradient:
+        allfinite_dlogp = allfinite(model.dlogp()(mx))
+    else:
+        allfinite_dlogp = True
+
+    if (not allfinite_mx0 or
+            not allfinite_logp or
+            not allfinite_dlogp):

         messages = []
         for var in vars:
-
             vals = {
                 "value": mx[var.name],
-                "logp": var.logp(mx),
-                "dlogp": var.dlogp()(mx)}
+                "logp": var.logp(mx)}
+            if compute_gradient:
+                vals["dlogp"] = var.dlogp()(mx)

             def message(name, values):
                 if np.size(values) < 10:
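
A condensed view of the pattern introduced above, offered as a rough sketch rather than a drop-in for starting.py (the helper name pick_fmin and the default optimizers are illustrative assumptions; the real code also considers discrete variables): probe whether the model can produce a gradient at all, and only build dlogp when the chosen optimizer actually accepts fprime.

from inspect import getargspec
from scipy.optimize import fmin_bfgs, fmin_powell

def pick_fmin(model, vars, fmin=None):
    # Probe gradient availability; non-differentiable ops make fastdlogp fail.
    try:
        model.fastdlogp(vars)
        gradient_avail = True
    except AttributeError:
        gradient_avail = False

    if fmin is None:
        fmin = fmin_bfgs if gradient_avail else fmin_powell

    # Only compute dlogp when the optimizer can consume it.
    compute_gradient = 'fprime' in getargspec(fmin).args
    return fmin, compute_gradient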
