@@ -17,8 +17,8 @@
 __all__ = ['find_MAP']
 
 
-def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
-             model=None, *args, **kwargs):
+def find_MAP(start=None, vars=None, fmin=None,
+             return_raw=False, model=None, *args, **kwargs):
     """
     Sets state to the local maximum a posteriori point given a model.
     Current default of fmin_Hessian does not deal well with optimizing close
@@ -55,8 +55,15 @@ def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
 
     disc_vars = list(typefilter(vars, discrete_types))
 
-    if disc_vars:
-        pm._log.warning("Warning: vars contains discrete variables. MAP " +
+    try:
+        model.fastdlogp(vars)
+        gradient_avail = True
+    except AttributeError:
+        gradient_avail = False
+
+    if disc_vars or not gradient_avail:
+        pm._log.warning("Warning: gradient not available. " +
+                        "(E.g. vars contains discrete variables). MAP " +
                         "estimates may not be accurate for the default " +
                         "parameters. Defaulting to non-gradient minimization " +
                         "fmin_powell.")
@@ -74,19 +81,21 @@ def find_MAP(start=None, vars=None, fmin=None, return_raw=False,
     bij = DictToArrayBijection(ArrayOrdering(vars), start)
 
     logp = bij.mapf(model.fastlogp)
-    dlogp = bij.mapf(model.fastdlogp(vars))
-
     def logp_o(point):
         return nan_to_high(-logp(point))
 
-    def grad_logp_o(point):
-        return nan_to_num(-dlogp(point))
-
     # Check to see if minimization function actually uses the gradient
     if 'fprime' in getargspec(fmin).args:
+        dlogp = bij.mapf(model.fastdlogp(vars))
+        def grad_logp_o(point):
+            return nan_to_num(-dlogp(point))
+
         r = fmin(logp_o, bij.map(
             start), fprime=grad_logp_o, *args, **kwargs)
+        compute_gradient = True
     else:
+        compute_gradient = False
+
         # Check to see if minimization function uses a starting value
         if 'x0' in getargspec(fmin).args:
             r = fmin(logp_o, bij.map(start), *args, **kwargs)
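With this change the dlogp compilation and the grad_logp_o wrapper are built only inside the 'fprime' branch, i.e. only when the chosen optimizer can actually accept a gradient callback. For readers unfamiliar with the introspection used here, a standalone illustration of the same check against two scipy optimizers (purely illustrative, not part of the commit; getargspec is deprecated in newer Pythons in favour of inspect.getfullargspec):

```python
from inspect import getargspec
from scipy import optimize

# fmin_bfgs exposes a gradient callback through its `fprime` argument,
# while fmin_powell is gradient-free but still takes a starting point `x0`.
print('fprime' in getargspec(optimize.fmin_bfgs).args)    # True
print('fprime' in getargspec(optimize.fmin_powell).args)  # False
print('x0' in getargspec(optimize.fmin_powell).args)      # True
```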
@@ -100,17 +109,24 @@ def grad_logp_o(point):
 
     mx = bij.rmap(mx0)
 
-    if (not allfinite(mx0) or
-        not allfinite(model.logp(mx)) or
-        not allfinite(model.dlogp()(mx))):
+    allfinite_mx0 = allfinite(mx0)
+    allfinite_logp = allfinite(model.logp(mx))
+    if compute_gradient:
+        allfinite_dlogp = allfinite(model.dlogp()(mx))
+    else:
+        allfinite_dlogp = True
+
+    if (not allfinite_mx0 or
+        not allfinite_logp or
+        not allfinite_dlogp):
 
         messages = []
         for var in vars:
-
             vals = {
                 "value": mx[var.name],
-                "logp": var.logp(mx),
-                "dlogp": var.dlogp()(mx)}
+                "logp": var.logp(mx)}
+            if compute_gradient:
+                vals["dlogp"] = var.dlogp()(mx)
 
             def message(name, values):
                 if np.size(values) < 10:
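Taken together, the commit lets find_MAP degrade gracefully when no gradient can be compiled, instead of failing on the dlogp call. A minimal usage sketch of the resulting behaviour; the model, the distributions, and the import alias are illustrative assumptions rather than part of this commit (depending on the checkout the package imports as pymc or pymc3):

```python
from scipy import optimize
import pymc3 as pm   # assumption: adjust to `import pymc as pm` for older checkouts

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0, sd=10)   # continuous: gradient available
    k = pm.Poisson('k', mu=5)           # discrete: no gradient

    # Including the discrete variable makes disc_vars non-empty, so the
    # warning above fires and find_MAP falls back to fmin_powell.
    map_estimate = pm.find_MAP(vars=[mu, k])

    # The optimizer can also be chosen explicitly; any fmin-style function
    # without an fprime argument skips the gradient machinery entirely.
    map_powell = pm.find_MAP(vars=[mu, k], fmin=optimize.fmin_powell)

print(map_estimate)   # dict mapping variable names to their MAP values
```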