Skip to content

Commit 1e6b6e9

Browse files
authored
mv normal refactor, fix float64 (#2329)
(cherry picked from commit 52380f6)
1 parent 748c61c commit 1e6b6e9

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

pymc3/distributions/multivariate.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def logp(self, value):
214214

215215
def _logp_chol(self, delta):
216216
chol_cov = self.chol_cov
217-
n, k = delta.shape
218-
217+
_, k = delta.shape
218+
k = pm.floatX(k)
219219
diag = tt.nlinalg.diag(chol_cov)
220220
# Check if the covariance matrix is positive definite.
221221
ok = tt.all(diag > 0)
@@ -225,15 +225,16 @@ def _logp_chol(self, delta):
225225

226226
delta_trans = self.solve_lower(chol_cov, delta.T)
227227

228-
result = k * np.log(2 * np.pi)
228+
result = k * pm.floatX(np.log(2. * np.pi))
229229
result += 2.0 * tt.sum(tt.log(diag))
230230
result += (delta_trans ** 2).sum(axis=0)
231231
result = -0.5 * result
232232
return bound(result, ok)
233233

234234
def _logp_cov(self, delta):
235235
chol_cov = self.chol_cov
236-
n, k = delta.shape
236+
_, k = delta.shape
237+
k = pm.floatX(k)
237238

238239
diag = tt.nlinalg.diag(chol_cov)
239240
ok = tt.all(diag > 0)
@@ -242,15 +243,16 @@ def _logp_cov(self, delta):
242243
diag = tt.nlinalg.diag(chol_cov)
243244
delta_trans = self.solve_lower(chol_cov, delta.T)
244245

245-
result = k * tt.log(2 * np.pi)
246+
result = k * pm.floatX(np.log(2. * np.pi))
246247
result += 2.0 * tt.sum(tt.log(diag))
247248
result += (delta_trans ** 2).sum(axis=0)
248249
result = -0.5 * result
249250
return bound(result, ok)
250251

251252
def _logp_tau(self, delta):
252253
chol_tau = self.chol_tau
253-
n, k = delta.shape
254+
_, k = delta.shape
255+
k = pm.floatX(k)
254256

255257
diag = tt.nlinalg.diag(chol_tau)
256258
ok = tt.all(diag > 0)
@@ -259,7 +261,7 @@ def _logp_tau(self, delta):
259261
diag = tt.nlinalg.diag(chol_tau)
260262
delta_trans = tt.dot(chol_tau.T, delta.T)
261263

262-
result = k * tt.log(2 * np.pi)
264+
result = k * pm.floatX(np.log(2. * np.pi))
263265
result -= 2.0 * tt.sum(tt.log(diag))
264266
result += (delta_trans ** 2).sum(axis=0)
265267
result = -0.5 * result

0 commit comments

Comments
 (0)