Skip to content

Commit 8b94a99

Browse files
refactor: moved _precision_mv_normal_logp into pmx
1 parent 57a7935 commit 8b94a99

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

pymc_extras/model/marginal/distributions.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from pymc.distributions import Bernoulli, Categorical, DiscreteUniform
1010
from pymc.distributions.distribution import _support_point, support_point
11-
from pymc.distributions.multivariate import _precision_mv_normal_logp
11+
from pymc.distributions.multivariate import _logdet_from_cholesky, nan_lower_cholesky
1212
from pymc.logprob.abstract import MeasurableOp, _logprob
1313
from pymc.logprob.basic import conditional_logp, logp
1414
from pymc.pytensorf import constant_fold
@@ -393,6 +393,17 @@ def step_alpha(logp_emission, log_alpha, log_P):
393393
return joint_logp, *dummy_logps
394394

395395

396+
def _precision_mv_normal_logp(value, mean, tau):
397+
k = value.shape[-1].astype("floatX")
398+
399+
delta = value - mean
400+
quadratic_form = delta.T @ tau @ delta
401+
logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau))
402+
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet
403+
404+
return logp, posdef
405+
406+
396407
@_logprob.register(MarginalLaplaceRV)
397408
def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
398409
# Clone the inner RV graph of the Marginalized RV

0 commit comments

Comments
 (0)