diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 4415f5ec2..1856d33c9 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -338,10 +338,7 @@ def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None): method=method, )(rng, size, mean, tau) - -@_logprob.register -def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): - [value] = value +def _precision_mv_normal_logp(value, mean, tau): k = value.shape[-1].astype("floatX") delta = value - mean @@ -349,6 +346,14 @@ def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, ta logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau)) logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet + return logp, posdef + +@_logprob.register +def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): + [value] = value + + logp, posdef = _precision_mv_normal_logp(value, mean, tau) + return check_parameters( logp, posdef,