Commit f58327a

.refactor

1 parent 3feaf2e commit f58327a
2 files changed (+173, -32 lines)
pymc_experimental/gp/pytensor_gp.py

Lines changed: 33 additions & 7 deletions
@@ -12,6 +12,24 @@
 class GPCovariance(OpFromGraph):
     """OFG representing a GP covariance"""
 
+    @staticmethod
+    def square_dist_Xs(X, Xs, ls):
+        assert X.ndim == 2, "Complain to Bill about it"
+        assert Xs.ndim == 2, "Complain to Bill about it"
+
+        X = X / ls
+        Xs = Xs / ls
+
+        X2 = pt.sum(pt.square(X), axis=-1)
+        Xs2 = pt.sum(pt.square(Xs), axis=-1)
+
+        sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
+        # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + (
+        #     pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1))
+        # )
+
+        return pt.clip(sqd, 0, pt.inf)
+
     @staticmethod
     def square_dist(X, ls):
         X = X / ls
@@ -27,20 +45,27 @@ class ExpQuadCov(GPCovariance):
     """
 
     @classmethod
-    def exp_quad_full(cls, X, ls):
-        return pt.exp(-0.5 * cls.square_dist(X, ls))
+    def exp_quad_full(cls, X, Xs, ls):
+        return pt.exp(-0.5 * cls.square_dist_Xs(X, Xs, ls))
 
     @classmethod
-    def build_covariance(cls, X, ls):
+    def build_covariance(cls, X, Xs=None, *, ls):
         X = pt.as_tensor(X)
+        if Xs is None:
+            Xs = X
+        else:
+            Xs = pt.as_tensor(Xs)
         ls = pt.as_tensor(ls)
 
-        ofg = cls(inputs=[X, ls], outputs=[cls.exp_quad_full(X, ls)])
-        return ofg(X, ls)
+        out = cls.exp_quad_full(X, Xs, ls)
+        if Xs is X:
+            return cls(inputs=[X, ls], outputs=[out])(X, ls)
+        else:
+            return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls)
 
 
-def ExpQuad(X, ls):
-    return ExpQuadCov.build_covariance(X, ls)
+def ExpQuad(X, X_new=None, *, ls):
+    return ExpQuadCov.build_covariance(X, X_new, ls=ls)
 
 
 class WhiteNoiseCov(GPCovariance):

@@ -77,6 +102,7 @@ class GP(Continuous):
 
     @classmethod
     def dist(cls, cov, **kwargs):
+        # return Assert(msg="Don't know what a GP_RV is")(False)
         cov = pt.as_tensor(cov)
         mu = pt.zeros(cov.shape[-1])
         return super().dist([mu, cov], **kwargs)
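
For context: `square_dist_Xs` above computes the cross squared distance via the expansion ‖x − x′‖² = ‖x‖² + ‖x′‖² − 2xᵀx′ on lengthscale-scaled inputs, and `build_covariance`/`ExpQuad` now thread an optional second input through the OpFromGraph, so the same kernel can produce either a square covariance K(X, X) or a cross covariance K(X, X_new). A minimal usage sketch, not part of the commit (input values are illustrative; assumes a pymc_experimental checkout with this diff applied):

# Sketch: exercising the refactored ExpQuad API from the diff above.
import numpy as np

from pymc_experimental.gp.pytensor_gp import ExpQuad

X = np.arange(3.0)[:, None]       # "training" inputs, shape (3, 1)
X_new = np.array([[3.0], [4.0]])  # new inputs, shape (2, 1)

Kxx = ExpQuad(X, ls=1.0)          # square kernel: OFG over inputs [X, ls]
Kxs = ExpQuad(X, X_new, ls=1.0)   # cross kernel: OFG over inputs [X, Xs, ls]

print(Kxx.eval().shape)  # (3, 3)
print(Kxs.eval().shape)  # (3, 2)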

tests/test_gp.py

Lines changed: 140 additions & 25 deletions
@@ -1,27 +1,26 @@
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
-import pytest
 
 from pymc_experimental.gp.pytensor_gp import GP, ExpQuad
 
 
 def test_exp_quad():
     x = pt.arange(3)[:, None]
     ls = pt.ones(())
-    cov = ExpQuad.build_covariance(x, ls).eval()
+    cov = ExpQuad(x, ls=ls).eval()
     expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
 
     np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))
 
 
-@pytest.fixture(scope="session")
-def marginal_model():
+# @pytest.fixture(scope="session")
+def latent_model():
     with pm.Model() as m:
         X = pm.Data("X", np.arange(3)[:, None])
         y = np.full(3, np.pi)
         ls = 1.0
-        cov = ExpQuad(X, ls)
+        cov = ExpQuad(X, ls=ls)
         gp = GP("gp", cov=cov)
 
         sigma = 1.0

@@ -30,31 +29,147 @@ def marginal_model():
     return m
 
 
-def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model):
-    obs = marginal_model["obs"]
+def latent_model_old_API():
+    with pm.Model() as m:
+        X = pm.Data("X", np.arange(3)[:, None])
+        y = np.full(3, np.pi)
+        ls = 1.0
+        cov = pm.gp.cov.ExpQuad(1, ls)
+        gp_class = pm.gp.Latent(cov_func=cov)
+        gp = gp_class.prior("gp", X, reparameterize=False)
+
+        sigma = 1.0
+        obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)
 
-    # TODO: Bring these checks back after we implement marginalization of the GP RV
-    #
-    # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
-    #            for var in ancestors([obs])
-    #            if var.owner is not None) == 1
-    #
-    f = pm.compile_pymc([], obs)
-    #
-    # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
+    return m, gp_class
 
-    draws = np.stack([f() for _ in range(10_000)])
-    empirical_cov = np.cov(draws.T)
 
-    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+def test_latent_model_prior():
+    m = latent_model()
+    ref_m, _ = latent_model_old_API()
+
+    prior = pm.draw(m["gp"], draws=1000)
+    prior_ref = pm.draw(ref_m["gp"], draws=1000)
+
+    np.testing.assert_allclose(
+        prior.mean(),
+        prior_ref.mean(),
+        atol=0.1,
+    )
+
+    np.testing.assert_allclose(
+        prior.std(),
+        prior_ref.std(),
+        rtol=0.1,
+    )
+
+
+def test_latent_model_logp():
+    m = latent_model()
+    ip = m.initial_point()
+
+    ref_m, _ = latent_model_old_API()
+
+    np.testing.assert_allclose(
+        m.compile_logp()(ip),
+        ref_m.compile_logp()(ip),
+        rtol=1e-6,
+    )
+
+
+import arviz as az
+
+
+def gp_conditional(model, gp, Xnew, jitter=1e-6):
+    def _build_conditional(self, Xnew, f, cov, jitter):
+        X, ls = cov.owner.inputs
+
+        Kxx = cov
+        Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
+        Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
+
+        L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
+        # TODO: Use cho_solve
+        A = pt.linalg.solve_triangular(L, Kxs, lower=True)
+        v = pt.linalg.solve_triangular(L, f, lower=True)
+
+        mu = (A.mT @ v).T  # Vector?
+        cov = Kss - (A.mT @ A)
+
+        return mu, cov
+
+    with model.copy() as new_m:
+        gp = new_m[gp.name]
+        _, cov = gp.owner.op.dist_params(gp.owner)
+        mu_star, cov_star = _build_conditional(None, Xnew, gp, cov, jitter)
+        gp_star = pm.MvNormal("gp_star", mu_star, cov_star)
+    return new_m
+
+
+def test_latent_model_predict_new_x():
+    rng = np.random.default_rng(0)
+    new_x = np.array([3, 4])[:, None]
+
+    m = latent_model()
+    ref_m, ref_gp_class = latent_model_old_API()
+
+    posterior_idata = az.from_dict({"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))})
+
+    # with gp_extend_to_new_x(m):
+    with gp_conditional(m, m["gp"], new_x):
+        pred = (
+            pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"])
+            .posterior_predictive["gp_star"]
+            .values
+        )
+
+    with ref_m:
+        gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x)
+        pred_ref = (
+            pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"])
131+
.values
132+
)
133+
134+
np.testing.assert_allclose(
135+
pred.mean(),
136+
pred_ref.mean(),
137+
atol=0.1,
138+
)
50139

51140
np.testing.assert_allclose(
52-
empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
141+
pred.std(),
142+
pred_ref.std(),
143+
rtol=0.1,
53144
)
54145

55146

56-
def test_marginal_gp_logp(marginal_model):
57-
expected_logps = {"obs": -8.8778}
58-
point_logps = marginal_model.point_logps(round_vals=4)
59-
for v1, v2 in zip(point_logps.values(), expected_logps.values()):
60-
np.testing.assert_allclose(v1, v2, atol=1e-6)
147+
#
148+
# def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model, ):
149+
# obs = marginal_model["obs"]
150+
#
151+
# # TODO: Bring these checks back after we implement marginalization of the GP RV
152+
# #
153+
# # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
154+
# # for var in ancestors([obs])
155+
# # if var.owner is not None) == 1
156+
# #
157+
# f = pm.compile_pymc([], obs)
158+
# #
159+
# # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
160+
#
161+
# draws = np.stack([f() for _ in range(10_000)])
162+
# empirical_cov = np.cov(draws.T)
163+
#
164+
# expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
165+
#
166+
# np.testing.assert_allclose(
167+
# empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
168+
# )
169+
#
170+
#
171+
# def test_marginal_gp_logp(marginal_model):
172+
# expected_logps = {"obs": -8.8778}
173+
# point_logps = marginal_model.point_logps(round_vals=4)
174+
# for v1, v2 in zip(point_logps.values(), expected_logps.values()):
175+
# np.testing.assert_allclose(v1, v2, atol=1e-6)
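
For reference: `_build_conditional` inside `gp_conditional` above follows the standard latent-GP conditional, computed through a Cholesky factorization rather than an explicit inverse. Writing \(K_{xx}\) for the covariance at the observed inputs, \(K_{xs}\) for the cross covariance, \(K_{ss}\) for the covariance at the new inputs, \(f\) for the latent GP values, and \(\epsilon\) for the jitter, the helper computes

\begin{aligned}
L L^\top &= K_{xx} + \epsilon I, \qquad A = L^{-1} K_{xs}, \qquad v = L^{-1} f, \\
\mu_\ast &= A^\top v = K_{xs}^\top K_{xx}^{-1} f, \\
\Sigma_\ast &= K_{ss} - A^\top A = K_{ss} - K_{xs}^\top K_{xx}^{-1} K_{xs},
\end{aligned}

and then registers `gp_star ~ MvNormal(mu_star, cov_star)` on a copy of the model, which is what `pm.sample_posterior_predictive` draws from in `test_latent_model_predict_new_x`.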
