Skip to content

Commit 6d04966

Browse files
committed
Return elemwise logp in MvNormal
1 parent c6b786b commit 6d04966

File tree

1 file changed

+47
-23
lines changed

1 file changed

+47
-23
lines changed

pymc3/distributions/multivariate.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ class MvNormal(Continuous):
9999
def __init__(self, mu, cov=None, tau=None, chol=None, lower=True,
100100
*args, **kwargs):
101101
super(MvNormal, self).__init__(*args, **kwargs)
102+
if len(self.shape) > 2:
103+
raise ValueError("Only 1 or 2 dimensions are allowed.")
104+
102105
if not lower:
103106
chol = chol.T
104107
if len([i for i in [tau, cov, chol] if i is not None]) != 1:
@@ -122,6 +125,7 @@ def __init__(self, mu, cov=None, tau=None, chol=None, lower=True,
122125
raise ValueError('cov must be two dimensional.')
123126
self.chol_cov = cholesky(cov)
124127
self.cov = cov
128+
self._n = self.cov.shape[-1]
125129
elif tau is not None:
126130
self.k = tau.shape[0]
127131
self._cov_type = 'tau'
@@ -130,13 +134,15 @@ def __init__(self, mu, cov=None, tau=None, chol=None, lower=True,
130134
raise ValueError('tau must be two dimensional.')
131135
self.chol_tau = cholesky(tau)
132136
self.tau = tau
137+
self._n = self.tau.shape[-1]
133138
else:
134139
self.k = chol.shape[0]
135140
self._cov_type = 'chol'
136141
if chol.ndim != 2:
137142
raise ValueError('chol must be two dimensional.')
138143
self.chol_cov = tt.as_tensor_variable(chol)
139-
144+
self._n = self.chol_cov.shape[-1]
145+
140146
def random(self, point=None, size=None):
141147
if size is None:
142148
size = []
@@ -148,6 +154,9 @@ def random(self, point=None, size=None):
148154

149155
if self._cov_type == 'cov':
150156
mu, cov = draw_values([self.mu, self.cov], point=point)
157+
if mu.shape != cov[0].shape:
158+
raise ValueError("Shapes for mu and cov don't match")
159+
151160
try:
152161
dist = stats.multivariate_normal(
153162
mean=mu, cov=cov, allow_singular=True)
@@ -157,11 +166,17 @@ def random(self, point=None, size=None):
157166
return dist.rvs(size)
158167
elif self._cov_type == 'chol':
159168
mu, chol = draw_values([self.mu, self.chol_cov], point=point)
169+
if mu.shape != chol[0].shape:
170+
raise ValueError("Shapes for mu and chol don't match")
171+
160172
size.append(mu.shape[0])
161173
standard_normal = np.random.standard_normal(size)
162174
return mu + np.dot(standard_normal, chol.T)
163175
else:
164176
mu, tau = draw_values([self.mu, self.tau], point=point)
177+
if mu.shape != tau[0].shape:
178+
raise ValueError("Shapes for mu and tau don't match")
179+
165180
size.append(mu.shape[0])
166181
standard_normal = np.random.standard_normal(size)
167182
try:
@@ -171,23 +186,32 @@ def random(self, point=None, size=None):
171186
transformed = linalg.solve_triangular(
172187
chol, standard_normal.T, lower=True)
173188
return mu + transformed.T
174-
189+
175190
def logp(self, value):
    """Compute the elementwise log-probability of `value` under this MvNormal.

    Accepts a 1-d value (a single draw) or a 2-d value (a batch of draws,
    one per row) and returns a scalar or a vector of log-probabilities
    respectively. Raises ValueError for 0-d or >2-d input.
    """
    mu = self.mu

    # Reject ranks we cannot interpret as draw(s) from the distribution.
    if value.ndim > 2 or value.ndim == 0:
        raise ValueError('Invalid dimension for value: %s' % value.ndim)

    # Normalize a single draw to a one-row batch; remember to unwrap later.
    onedim = value.ndim == 1
    if onedim:
        value = value[None, :]

    delta = value - mu

    # Dispatch on how the covariance was parameterized at construction time.
    if self._cov_type == 'cov':
        # Use this when Theano#5908 is released.
        # return MvNormalLogp()(self.cov, delta)
        logp = self._logp_cov(delta)
    elif self._cov_type == 'tau':
        logp = self._logp_tau(delta)
    else:
        logp = self._logp_chol(delta)

    # A single draw gets a scalar back, not a length-1 vector.
    return logp[0] if onedim else logp
191215
def _logp_chol(self, delta):
192216
chol_cov = self.chol_cov
193217
n, k = delta.shape
@@ -200,10 +224,10 @@ def _logp_chol(self, delta):
200224
chol_cov = tt.switch(ok, chol_cov, 1)
201225

202226
delta_trans = self.solve_lower(chol_cov, delta.T)
203-
204-
result = n * k * np.log(2 * np.pi)
205-
result += 2.0 * n * tt.sum(tt.log(diag))
206-
result += (delta_trans ** 2).sum()
227+
228+
result = k * np.log(2 * np.pi)
229+
result += 2.0 * tt.sum(tt.log(diag))
230+
result += (delta_trans ** 2).sum(axis=0)
207231
result = -0.5 * result
208232
return bound(result, ok)
209233

@@ -217,10 +241,10 @@ def _logp_cov(self, delta):
217241
chol_cov = tt.switch(ok, chol_cov, 1)
218242
diag = tt.nlinalg.diag(chol_cov)
219243
delta_trans = self.solve_lower(chol_cov, delta.T)
220-
221-
result = n * k * tt.log(2 * np.pi)
222-
result += 2.0 * n * tt.sum(tt.log(diag))
223-
result += (delta_trans ** 2).sum()
244+
245+
result = k * tt.log(2 * np.pi)
246+
result += 2.0 * tt.sum(tt.log(diag))
247+
result += (delta_trans ** 2).sum(axis=0)
224248
result = -0.5 * result
225249
return bound(result, ok)
226250

@@ -234,10 +258,10 @@ def _logp_tau(self, delta):
234258
chol_tau = tt.switch(ok, chol_tau, 1)
235259
diag = tt.nlinalg.diag(chol_tau)
236260
delta_trans = tt.dot(chol_tau.T, delta.T)
237-
238-
result = n * k * tt.log(2 * np.pi)
239-
result -= 2.0 * n * tt.sum(tt.log(diag))
240-
result += (delta_trans ** 2).sum()
261+
262+
result = k * tt.log(2 * np.pi)
263+
result -= 2.0 * tt.sum(tt.log(diag))
264+
result += (delta_trans ** 2).sum(axis=0)
241265
result = -0.5 * result
242266
return bound(result, ok)
243267

@@ -247,7 +271,7 @@ def _repr_latex_(self, name=None, dist=None):
247271
mu = dist.mu
248272
try:
249273
cov = dist.cov
250-
except AttributeErrir:
274+
except AttributeError:
251275
cov = dist.chol_cov
252276
return r'${} \sim \text{{MvNormal}}(\mathit{{mu}}={}, \mathit{{cov}}={})$'.format(name,
253277
get_variable_name(mu),

0 commit comments

Comments (0)