Skip to content

Commit 3b13a69

Browse files
refactor: move _precition_mv_normal_logp into a seperate function
1 parent 0960323 commit 3b13a69

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

pymc/distributions/multivariate.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,17 +338,22 @@ def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None):
338338
method=method,
339339
)(rng, size, mean, tau)
340340

341-
342-
@_logprob.register
343-
def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs):
344-
[value] = value
341+
def _precision_mv_normal_logp(value, mean, tau):
345342
k = value.shape[-1].astype("floatX")
346343

347344
delta = value - mean
348345
quadratic_form = delta.T @ tau @ delta
349346
logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau))
350347
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet
351348

349+
return logp, posdef
350+
351+
@_logprob.register
352+
def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs):
353+
[value] = value
354+
355+
logp, posdef = _precision_mv_normal_logp(value, mean, tau)
356+
352357
return check_parameters(
353358
logp,
354359
posdef,

0 commit comments

Comments
 (0)