Skip to content

Commit 4819184

Browse files
committed
Implement Multivariate Laplace distribution
1 parent 0f5a818 commit 4819184

File tree

3 files changed

+193
-0
lines changed

3 files changed

+193
-0
lines changed

pymc_experimental/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
Skellam,
2525
)
2626
from pymc_experimental.distributions.histogram_utils import histogram_approximation
27+
from pymc_experimental.distributions.multivariate.laplace import MvLaplace
2728
from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
2829
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
2930

@@ -32,6 +33,7 @@
3233
"DiscreteMarkovChain",
3334
"GeneralizedPoisson",
3435
"GenExtreme",
36+
"MvLaplace",
3537
"R2D2M2CP",
3638
"Skellam",
3739
"histogram_approximation",
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import numpy as np
2+
import pytensor.tensor as pt
3+
import scipy
4+
5+
from pymc.distributions.dist_math import check_parameters
6+
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
7+
from pymc.distributions.multivariate import quaddist_chol, quaddist_matrix
8+
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
9+
from pymc.pytensorf import normalize_rng_param
10+
from pytensor.gradient import grad_not_implemented
11+
from pytensor.scalar import BinaryScalarOp, upgrade_to_float
12+
from pytensor.tensor.random.utils import normalize_size_param
13+
14+
15+
class Kv(BinaryScalarOp):
16+
"""
17+
Modified Bessel function of the second kind of real order v.
18+
"""
19+
20+
nfunc_spec = ("scipy.special.kv", 2, 1)
21+
22+
@staticmethod
23+
def st_impl(v, x):
24+
return scipy.special.kv(v, x)
25+
26+
def impl(self, v, x):
27+
return self.st_impl(v, x)
28+
29+
def grad(self, inputs, grads):
30+
v, x = inputs
31+
(gz,) = grads
32+
return [grad_not_implemented(self, 0, v), gz * kvp(v, x)]
33+
34+
def c_code(self, *args, **kwargs):
35+
raise NotImplementedError()
36+
37+
38+
kv = Kv(upgrade_to_float, name="kv")
39+
40+
41+
class Kvp(BinaryScalarOp):
42+
"""
43+
First-order derivative of real-order Modified Bessel function of the second kind Kv(z)
44+
"""
45+
46+
nfunc_spec = ("scipy.special.kvp", 2, 1)
47+
48+
@staticmethod
49+
def st_impl(v, x):
50+
return scipy.special.kvp(v, x)
51+
52+
def impl(self, v, x):
53+
return self.st_impl(v, x)
54+
55+
def c_code(self, *args, **kwargs):
56+
raise NotImplementedError()
57+
58+
59+
kvp = Kvp(upgrade_to_float, name="kvp")
60+
61+
62+
class MultivariateLaplaceRV(SymbolicRandomVariable):
63+
name = "multivariate_laplace"
64+
extended_signature = "[rng],[size],(m),(m,m)->[rng],(m)"
65+
_print_name = ("MultivariateLaplace", "\\operatorname{MultivariateLaplace}")
66+
67+
@classmethod
68+
def rv_op(cls, mu, cov, *, size=None, rng=None):
69+
mu = pt.as_tensor(mu)
70+
cov = pt.as_tensor(cov)
71+
rng = normalize_rng_param(rng)
72+
size = normalize_size_param(size)
73+
74+
assert mu.type.ndim >= 1
75+
assert cov.type.ndim >= 2
76+
77+
if rv_size_is_none(size):
78+
size = implicit_size_from_params(mu, cov, ndims_params=(1, 2))
79+
80+
next_rng, e = pt.random.exponential(size=size, rng=rng).owner.outputs
81+
next_rng, z = pt.random.multivariate_normal(
82+
mean=pt.zeros(mu.shape[-1]), cov=cov, size=size, rng=next_rng
83+
).owner.outputs
84+
rv = mu + pt.sqrt(e)[..., None] * z
85+
86+
return cls(
87+
inputs=[rng, size, mu, cov],
88+
outputs=[next_rng, rv],
89+
)(rng, size, mu, cov)
90+
91+
92+
class MvLaplace(Continuous):
93+
r"""Multivariate (Symmetric) Laplace distribution."""
94+
95+
rv_type = MultivariateLaplaceRV
96+
rv_op = MultivariateLaplaceRV.rv_op
97+
98+
@classmethod
99+
def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
100+
cov = quaddist_matrix(cov, chol, tau, lower)
101+
102+
mu = pt.as_tensor_variable(mu)
103+
if mu.type.broadcastable[-1] != cov.type.broadcastable[-1]:
104+
mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
105+
return super().dist([mu, cov], **kwargs)
106+
107+
def support_point(rv, size, mu, cov):
108+
if rv_size_is_none(size):
109+
broadcasted_mu, _ = pt.random.utils.broadcast_params([mu, cov], ndims_params=[1, 2])
110+
else:
111+
broadcast_shape = pt.concatenate([size, [mu.shape[-1]]])
112+
broadcasted_mu = pt.broadcast_to(mu, broadcast_shape)
113+
return broadcasted_mu
114+
115+
def logp(value, mu, cov):
116+
quaddist, logdet, posdef = quaddist_chol(value, mu, cov)
117+
118+
k = value.shape[-1].astype("floatX")
119+
norm = np.log(2) - 0.5 * k * np.log(2 * np.pi) - logdet
120+
121+
v = 1 - (k / 2)
122+
kernel = ((v / 2) * pt.log(quaddist / 2)) + pt.log(kv(v, pt.sqrt(2 * quaddist)))
123+
124+
logp_val = norm + kernel
125+
return check_parameters(logp_val, posdef, msg="posdef scale")
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import numpy as np
2+
import pymc as pm
3+
import scipy
4+
5+
from pymc_experimental.distributions.multivariate.laplace import MvLaplace
6+
7+
# TODO: Test mvlaplace_support_point
8+
9+
10+
def test_mvlaplace_random():
11+
mu = [-1, np.pi, 1]
12+
cov = [[1, 0.5, 0.25], [0.5, 2, 0.5], [0.25, 0.5, 3]]
13+
rv = MvLaplace.dist(mu=mu, cov=cov, size=10_000)
14+
15+
samples = pm.draw(rv, random_seed=13)
16+
assert samples.shape == (10_000, 3)
17+
np.testing.assert_allclose(np.mean(samples, axis=0), mu, rtol=0.05)
18+
np.testing.assert_allclose(np.cov(samples, rowvar=False), cov, rtol=0.1)
19+
20+
21+
def test_laplace_logp():
22+
# Testing against special bivariate cases described in:
23+
# https://en.wikipedia.org/wiki/Multivariate_Laplace_distribution#Probability_density_function
24+
25+
# Zero mean, non-identity covariance case
26+
mu = np.zeros(2)
27+
s1 = 0.5
28+
s2 = 2.0
29+
r = -0.25
30+
cov = np.array(
31+
[
32+
[s1**2, r * s1 * s2],
33+
[r * s1 * s2, s2**2],
34+
]
35+
)
36+
rv = MvLaplace.dist(mu=mu, cov=cov)
37+
rv_val = np.random.normal(size=(2,))
38+
logp_eval = pm.logp(rv, rv_val).eval()
39+
40+
x1, x2 = rv_val
41+
logp_expected = np.log(
42+
(1 / (np.pi * s1 * s2 * np.sqrt(1 - r**2)))
43+
* scipy.special.kv(
44+
0,
45+
np.sqrt(
46+
(2 * ((x1**2 / s1**2) - (2 * r * x1 * x2 / (s1 * s2)) + (x2**2 / s2**2)))
47+
/ (1 - r**2)
48+
),
49+
)
50+
)
51+
np.testing.assert_allclose(
52+
logp_eval,
53+
logp_expected,
54+
)
55+
56+
# Non zero mean, identity covariance case
57+
mu = np.array([1, 3])
58+
rv = MvLaplace.dist(mu=mu, cov=np.eye(2))
59+
rv_val = np.random.normal(size=(2,))
60+
logp_eval = pm.logp(rv, rv_val).eval()
61+
62+
logp_expected = np.log(1 / np.pi * scipy.special.kv(0, np.sqrt(2 * np.sum((rv_val - mu) ** 2))))
63+
np.testing.assert_allclose(
64+
logp_eval,
65+
logp_expected,
66+
)

0 commit comments

Comments
 (0)