Commit 866c482

Authored by aseyboldt, committed by Junpeng Lao
Add logp_nojac and logp_sum (#2499)
* Add logp_nojac and logp_sum
* Add model name to name of logp variable
1 parent: 6e02dbc · commit: 866c482

4 files changed (+101 / -10 lines)

pymc3/distributions/distribution.py

Lines changed: 22 additions & 0 deletions
@@ -91,6 +91,28 @@ def _repr_latex_(self, name=None, dist=None):
         """Magic method name for IPython to use for LaTeX formatting."""
         return None
 
+    def logp_nojac(self, *args, **kwargs):
+        """Return the logp, but do not include a jacobian term for transforms.
+
+        If we use different parametrizations for the same distribution, we
+        need to add the determinant of the jacobian of the transformation
+        to make sure the densities still describe the same distribution.
+        However, MAP estimates are not invariant with respect to the
+        parametrization, we need to exclude the jacobian terms in this case.
+
+        This function should be overwritten in base classes for transformed
+        distributions.
+        """
+        return self.logp(*args, **kwargs)
+
+    def logp_sum(self, *args, **kwargs):
+        """Return the sum of the logp values for the given observations.
+
+        Subclasses can use this to improve the speed of logp evaluations
+        if only the sum of the logp values is needed.
+        """
+        return tt.sum(self.logp(*args, **kwargs))
+
     __latex__ = _repr_latex_

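For a plain (untransformed) distribution the two new methods are thin wrappers around logp: logp_nojac returns it unchanged and logp_sum sums it. A minimal sketch of that behaviour, assuming pymc3 at this commit with Theano available; the Normal distribution and the values are chosen only for illustration and are not part of the diff:

import numpy as np
import pymc3 as pm

values = np.array([-1.0, 0.0, 1.0])
dist = pm.Normal.dist(mu=0.0, sd=1.0)   # no transform involved

logp_vals = dist.logp(values).eval()
# logp_nojac falls back to logp; logp_sum is the summed elementwise logp.
assert np.allclose(dist.logp_nojac(values).eval(), logp_vals)
assert np.isclose(dist.logp_sum(values).eval(), logp_vals.sum())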
pymc3/distributions/transforms.py

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ def logp(self, x):
         return (self.dist.logp(self.transform_used.backward(x)) +
                 self.transform_used.jacobian_det(x))
 
+    def logp_nojac(self, x):
+        return self.dist.logp(self.transform_used.backward(x))
+
 transform = Transform

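This override is where logp and logp_nojac actually diverge: for a transformed variable, logp adds the transform's jacobian_det while logp_nojac omits it. A hedged sketch of that relationship, assuming a HalfNormal variable that receives the automatic log transform; the model and value are invented for illustration and are not part of the diff:

import numpy as np
import theano.tensor as tt
import pymc3 as pm

with pm.Model() as model:
    pm.HalfNormal('x', sd=1.0)

x_log = model.free_RVs[0]          # the log-transformed free variable 'x_log__'
trans_dist = x_log.distribution    # a TransformedDistribution
val = tt.constant(0.5)             # a point on the transformed (log) scale

with_jac = trans_dist.logp(val).eval()
without_jac = trans_dist.logp_nojac(val).eval()
jac = trans_dist.transform_used.jacobian_det(val).eval()
# logp = logp_nojac + log|det J| of the transform at this point.
assert np.isclose(with_jac, without_jac + jac)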
pymc3/model.py

Lines changed: 70 additions & 6 deletions
@@ -172,6 +172,18 @@ def d2logp(self, vars=None):
         """Compiled log probability density hessian function"""
         return self.model.fn(hessian(self.logpt, vars))
 
+    @property
+    def logp_nojac(self):
+        return self.model.fn(self.logp_nojact)
+
+    def dlogp_nojac(self, vars=None):
+        """Compiled log density gradient function, without jacobian terms."""
+        return self.model.fn(gradient(self.logp_nojact, vars))
+
+    def d2logp_nojac(self, vars=None):
+        """Compiled log density hessian function, without jacobian terms."""
+        return self.model.fn(hessian(self.logp_nojact, vars))
+
     @property
     def fastlogp(self):
         """Compiled log probability density function"""
@@ -185,13 +197,36 @@ def fastd2logp(self, vars=None):
         """Compiled log probability density hessian function"""
         return self.model.fastfn(hessian(self.logpt, vars))
 
+    @property
+    def fastlogp_nojac(self):
+        return self.model.fastfn(self.logp_nojact)
+
+    def fastdlogp_nojac(self, vars=None):
+        """Compiled log density gradient function, without jacobian terms."""
+        return self.model.fastfn(gradient(self.logp_nojact, vars))
+
+    def fastd2logp_nojac(self, vars=None):
+        """Compiled log density hessian function, without jacobian terms."""
+        return self.model.fastfn(hessian(self.logp_nojact, vars))
+
     @property
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         if getattr(self, 'total_size', None) is not None:
-            logp = tt.sum(self.logp_elemwiset) * self.scaling
+            logp = self.logp_sum_unscaledt * self.scaling
+        else:
+            logp = self.logp_sum_unscaledt
+        if self.name is not None:
+            logp.name = '__logp_%s' % self.name
+        return logp
+
+    @property
+    def logp_nojact(self):
+        """Theano scalar of log-probability, excluding jacobian terms."""
+        if getattr(self, 'total_size', None) is not None:
+            logp = tt.sum(self.logp_nojac_unscaledt) * self.scaling
         else:
-            logp = tt.sum(self.logp_elemwiset)
+            logp = tt.sum(self.logp_nojac_unscaledt)
         if self.name is not None:
             logp.name = '__logp_%s' % self.name
         return logp
@@ -626,9 +661,26 @@ def logp_dlogp_function(self, grad_vars=None, **kwargs):
     def logpt(self):
         """Theano scalar of log-probability of the model"""
         with self:
-            factors = [var.logpt for var in self.basic_RVs] + self.potentials
-            logp = tt.add(*map(tt.sum, factors))
-            logp.name = '__logp'
+            factors = [var.logpt for var in self.basic_RVs]
+            logp_factors = tt.sum(factors)
+            logp_potentials = tt.sum([tt.sum(pot) for pot in self.potentials])
+            logp = logp_factors + logp_potentials
+            if self.name:
+                logp.name = '__logp_%s' % self.name
+            else:
+                logp.name = '__logp'
+            return logp
+
+    @property
+    def logp_nojact(self):
+        """Theano scalar of log-probability of the model"""
+        with self:
+            factors = [var.logp_nojact for var in self.basic_RVs] + self.potentials
+            logp = tt.sum([tt.sum(factor) for factor in factors])
+            if self.name:
+                logp.name = '__logp_nojac_%s' % self.name
+            else:
+                logp.name = '__logp_nojac'
         return logp
 
     @property
@@ -637,7 +689,7 @@ def varlogpt(self):
         (excluding deterministic)."""
         with self:
             factors = [var.logpt for var in self.vars]
-            return tt.add(*map(tt.sum, factors))
+            return tt.sum(factors)
 
     @property
     def vars(self):
@@ -1069,6 +1121,10 @@ def __init__(self, type=None, owner=None, index=None, name=None,
             self.tag.test_value = np.ones(
                 distribution.shape, distribution.dtype) * distribution.default()
             self.logp_elemwiset = distribution.logp(self)
+            # The logp might need scaling in minibatches.
+            # This is done in `Factor`.
+            self.logp_sum_unscaledt = distribution.logp_sum(self)
+            self.logp_nojac_unscaledt = distribution.logp_nojac(self)
             self.total_size = total_size
             self.model = model
             self.scaling = _get_scaling(total_size, self.shape, self.ndim)
@@ -1172,6 +1228,10 @@ def __init__(self, type=None, owner=None, index=None, name=None, data=None,
 
             self.missing_values = data.missing_values
             self.logp_elemwiset = distribution.logp(data)
+            # The logp might need scaling in minibatches.
+            # This is done in `Factor`.
+            self.logp_sum_unscaledt = distribution.logp_sum(data)
+            self.logp_nojac_unscaledt = distribution.logp_nojac(data)
             self.total_size = total_size
             self.model = model
             self.distribution = distribution
@@ -1223,6 +1283,10 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
        self.missing_values = [datum.missing_values for datum in self.data.values()
                               if datum.missing_values is not None]
        self.logp_elemwiset = distribution.logp(**self.data)
+       # The logp might need scaling in minibatches.
+       # This is done in `Factor`.
+       self.logp_sum_unscaledt = distribution.logp_sum(**self.data)
+       self.logp_nojac_unscaledt = distribution.logp_nojac(**self.data)
        self.total_size = total_size
        self.model = model
        self.distribution = distribution

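At the model level the new logp_nojact property and the *_nojac compiled functions mirror the existing logpt/fastlogp, but are built from the unscaled logp_nojac terms of each variable. A hedged sketch comparing the two at a point, with an invented toy model (not part of the diff):

import numpy as np
import pymc3 as pm

np.random.seed(123)
data = np.random.randn(10)

with pm.Model() as model:
    sd = pm.HalfNormal('sd', sd=1.0)            # gets an automatic log transform
    pm.Normal('y', mu=0.0, sd=sd, observed=data)

point = model.test_point
# fastlogp includes the jacobian term of the log transform; fastlogp_nojac does not.
print(model.fastlogp(point))
print(model.fastlogp_nojac(point))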
pymc3/tuning/starting.py

Lines changed: 6 additions & 4 deletions
@@ -18,8 +18,10 @@
 
 __all__ = ['find_MAP']
 
+
 def find_MAP(start=None, vars=None, fmin=None,
-             return_raw=False, model=None, live_disp=False, callback=None, *args, **kwargs):
+             return_raw=False, model=None, live_disp=False, callback=None,
+             *args, **kwargs):
     """
     Sets state to the local maximum a posteriori point given a model.
     Current default of fmin_Hessian does not deal well with optimizing close
@@ -69,7 +71,7 @@ def find_MAP(start=None, vars=None, fmin=None,
     except AttributeError:
         gradient_avail = False
 
-    if disc_vars or not gradient_avail :
+    if disc_vars or not gradient_avail:
         pm._log.warning("Warning: gradient not available." +
                         "(E.g. vars contains discrete variables). MAP " +
                         "estimates may not be accurate for the default " +
@@ -88,13 +90,13 @@ def find_MAP(start=None, vars=None, fmin=None,
     start = Point(start, model=model)
     bij = DictToArrayBijection(ArrayOrdering(vars), start)
 
-    logp = bij.mapf(model.fastlogp)
+    logp = bij.mapf(model.fastlogp_nojac)
     def logp_o(point):
         return nan_to_high(-logp(point))
 
     # Check to see if minimization function actually uses the gradient
     if 'fprime' in getargspec(fmin).args:
-        dlogp = bij.mapf(model.fastdlogp(vars))
+        dlogp = bij.mapf(model.fastdlogp_nojac(vars))
         def grad_logp_o(point):
             return nan_to_num(-dlogp(point))

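With the two substitutions above, find_MAP maximizes the density without jacobian terms, so the estimate should correspond to the mode in the parametrization the model was written in rather than on the transformed scale. A hedged usage sketch with invented data (not part of the diff):

import numpy as np
import pymc3 as pm

np.random.seed(42)
data = np.random.randn(100)

with pm.Model() as model:
    mu = pm.Normal('mu', mu=0.0, sd=10.0)
    sd = pm.HalfNormal('sd', sd=10.0)
    pm.Normal('y', mu=mu, sd=sd, observed=data)
    map_point = pm.find_MAP()

# The optimizer targeted logp_nojact rather than logpt:
print(model.fastlogp_nojac(map_point) >= model.fastlogp_nojac(model.test_point))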