
Commit c7e84fb

Remove dead code
1 parent e9c1a9c

File tree

2 files changed: +20 -142 lines changed

pymc_experimental/gp/pytensor_gp.py

Lines changed: 3 additions & 124 deletions
@@ -13,7 +13,7 @@ class GPCovariance(OpFromGraph):
     """OFG representing a GP covariance"""

     @staticmethod
-    def square_dist_Xs(X, Xs, ls):
+    def square_dist(X, Xs, ls):
         assert X.ndim == 2, "Complain to Bill about it"
         assert Xs.ndim == 2, "Complain to Bill about it"

@@ -24,20 +24,8 @@ def square_dist_Xs(X, Xs, ls):
         Xs2 = pt.sum(pt.square(Xs), axis=-1)

         sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
-        # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + (
-        #     pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1))
-        # )
-
         return pt.clip(sqd, 0, pt.inf)

-    @staticmethod
-    def square_dist(X, ls):
-        X = X / ls
-        X2 = pt.sum(pt.square(X), axis=-1)
-        sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :])
-
-        return sqd
-

 class ExpQuadCov(GPCovariance):
     """
@@ -46,7 +34,7 @@ class ExpQuadCov(GPCovariance):

     @classmethod
     def exp_quad_full(cls, X, Xs, ls):
-        return pt.exp(-0.5 * cls.square_dist_Xs(X, Xs, ls))
+        return pt.exp(-0.5 * cls.square_dist(X, Xs, ls))

     @classmethod
     def build_covariance(cls, X, Xs=None, *, ls):
@@ -64,32 +52,10 @@ def build_covariance(cls, X, Xs=None, *, ls):
         return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls)


-def ExpQuad(X, X_new=None, *, ls):
+def ExpQuad(X, X_new=None, *, ls=1.0):
     return ExpQuadCov.build_covariance(X, X_new, ls=ls)


-# class WhiteNoiseCov(GPCovariance):
-#     @classmethod
-#     def white_noise_full(cls, X, sigma):
-#         X_shape = tuple(X.shape)
-#         shape = X_shape[:-1] + (X_shape[-2],)
-#
-#         return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
-#
-#     @classmethod
-#     def build_covariance(cls, X, sigma):
-#         X = pt.as_tensor(X)
-#         sigma = pt.as_tensor(sigma)
-#
-#         ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
-#         return ofg(X, sigma)
-
-#
-# def WhiteNoise(X, sigma):
-#     return WhiteNoiseCov.build_covariance(X, sigma)
-#
-
-
 class GP_RV(pm.MvNormal.rv_type):
     name = "gaussian_process"
     signature = "(n),(n,n)->(n)"
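
The one user-facing change in this hunk is that ExpQuad's lengthscale now defaults to 1.0. A hedged usage sketch (standalone, not from the diff; the import path follows the repo's tests):

    # The kernel can now be evaluated without an explicit lengthscale;
    # ExpQuad(X) is assumed equivalent to ExpQuad(X, ls=1.0) per the new default.
    import numpy as np
    import pytensor.tensor as pt

    from pymc_experimental.gp.pytensor_gp import ExpQuad

    X = pt.as_tensor(np.arange(3, dtype="float64")[:, None])
    K = ExpQuad(X).eval()  # lengthscale defaults to ls=1.0
    K_explicit = ExpQuad(X, ls=1.0).eval()  # same covariance, ls spelled out
    np.testing.assert_allclose(K, K_explicit)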
@@ -103,7 +69,6 @@ class GP(Continuous):

     @classmethod
     def dist(cls, cov, **kwargs):
-        # return Assert(msg="Don't know what a GP_RV is")(False)
         cov = pt.as_tensor(cov)
         mu = pt.zeros(cov.shape[-1])
         return super().dist([mu, cov], **kwargs)
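
As the hunk shows, GP.dist wraps a zero-mean multivariate normal over the supplied covariance. Conceptually (a sketch of the visible logic only, not the class's full implementation):

    # Conceptual equivalent of GP.dist above: a zero mean vector sized to
    # match the covariance, forwarded to an MvNormal-style distribution.
    import pymc as pm
    import pytensor.tensor as pt

    def gp_dist_sketch(cov, **kwargs):
        cov = pt.as_tensor(cov)
        mu = pt.zeros(cov.shape[-1])  # zero mean matching the covariance size
        return pm.MvNormal.dist(mu, cov=cov, **kwargs)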
@@ -190,89 +155,3 @@ def _build_conditional(Xnew, f, cov, jitter):
     fgraph.add_output(gp_model_var_star, import_missing=True)

     return model_from_fgraph(fgraph, mutate_fgraph=True)
-
-
-# @register_canonicalize
-# @node_rewriter(tracks=[pm.MvNormal.rv_type])
-# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node):
-#     # TODO: Should this alert users that it can't be applied when the GP is in a deterministic?
-#     gp_rng, gp_size, mu, cov = node.inputs
-#     next_gp_rng, gp_rv = node.outputs
-#
-#     if not isinstance(cov.owner.op, GPCovariance):
-#         return
-#
-#     for client, input_index in fgraph.clients[gp_rv]:
-#         # input_index is 2 because it goes (rng, size, mu, sigma), and we want the mu
-#         # to be the GP we're looking
-#         if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2):
-#             next_normal_rng, normal_rv = client.outputs
-#             normal_rng, normal_size, mu, sigma = client.inputs
-#
-#             if normal_rv.ndim != gp_rv.ndim:
-#                 return
-#
-#             X = cov.owner.inputs[0]
-#
-#             white_noise = WhiteNoiseCov.build_covariance(X, sigma)
-#             white_noise.name = 'WhiteNoiseCov'
-#             cov = cov + white_noise
-#
-#             if not rv_size_is_none(normal_size):
-#                 normal_size = tuple(normal_size)
-#                 new_gp_size = normal_size[:-1]
-#                 core_shape = normal_size[-1]
-#
-#                 cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape)
-#                 cov = pt.specify_shape(cov, cov_shape)
-#
-#             else:
-#                 new_gp_size = None
-#
-#             next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs
-#             new_gp_mvn.name = 'NewGPMvn'
-#
-#             # Check that the new shape is at least as specific as the shape we are replacing
-#             for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True):
-#                 if new_shape is None:
-#                     assert old_shape is None
-#
-#             return {
-#                 next_normal_rng: next_new_gp_rng,
-#                 normal_rv: new_gp_mvn,
-#                 next_gp_rng: next_new_gp_rng
-#             }
-#
-#         else:
-#             return None
-#
-#     #TODO: Why do I need to register this twice?
-# specialization_ir_rewrites_db.register(
-#     GP_normal_mvnormal_conjugacy.__name__,
-#     GP_normal_mvnormal_conjugacy,
-#     "basic",
-# )
-
-# @node_rewriter(tracks=[pm.MvNormal.rv_type])
-# def GP_normal_marginal_logp(fgraph: FunctionGraph, node):
-#     """
-#     Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)).
-#     """
-#     rng, size, mu, cov = node.inputs
-#     if cov.owner and cov.owner.op == matrix_inverse:
-#         tau = cov.owner.inputs[0]
-#         return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
-#     return None
-#
-
-# cov_op = GPCovariance()
-# gp_op = GP("vanilla")
-# # SymbolicRandomVariable.register(type(gp_op))
-# prior_from_gp = PriorFromGP()
-#
-# MeasurableVariable.register(type(prior_from_gp))
-#
-#
-# @_get_measurable_outputs.register(type(prior_from_gp))
-# def gp_measurable_outputs(op, node):
-#     return node.outputs
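
For context, the deleted conjugacy rewrite encoded the usual marginalization identity: if f ~ MvNormal(0, K) and y | f ~ Normal(f, sigma), then y ~ MvNormal(0, K + sigma^2 * I). A quick Monte Carlo check of that identity (standalone NumPy, unrelated to the removed PyTensor machinery; all names here are ad hoc):

    # Marginalizing f ~ N(0, K) out of y | f ~ N(f, sigma**2 * I)
    # gives y ~ N(0, K + sigma**2 * I); verify empirically.
    import numpy as np

    rng = np.random.default_rng(42)
    A = rng.normal(size=(3, 3))
    K = A @ A.T + np.eye(3)  # an arbitrary positive-definite covariance
    sigma = 0.7

    f = rng.multivariate_normal(np.zeros(3), K, size=200_000)
    y = f + sigma * rng.normal(size=f.shape)  # Normal noise around each draw

    # Empirical covariance of y should approach K + sigma**2 * I.
    np.testing.assert_allclose(np.cov(y.T), K + sigma**2 * np.eye(3), atol=0.1)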

tests/test_gp.py

Lines changed: 17 additions & 18 deletions
@@ -7,17 +7,7 @@
 from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp


-def test_exp_quad():
-    x = pt.arange(3)[:, None]
-    ls = pt.ones(())
-    cov = ExpQuad(x, ls=ls).eval()
-    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
-
-    np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))
-
-
-# @pytest.fixture(scope="session")
-def latent_model():
+def build_latent_model():
     with pm.Model() as m:
         X = pm.Data("X", np.arange(3)[:, None])
         y = np.full(3, np.pi)
@@ -31,7 +21,7 @@ def latent_model():
     return m


-def latent_model_old_API():
+def build_latent_model_old_API():
     with pm.Model() as m:
         X = pm.Data("X", np.arange(3)[:, None])
         y = np.full(3, np.pi)
@@ -46,9 +36,18 @@ def latent_model_old_API():
     return m, gp_class


+def test_exp_quad():
+    x = pt.arange(3)[:, None]
+    ls = pt.ones(())
+    cov = ExpQuad(x, ls=ls).eval()
+    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+
+    np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))
+
+
 def test_latent_model_prior():
-    m = latent_model()
-    ref_m, _ = latent_model_old_API()
+    m = build_latent_model()
+    ref_m, _ = build_latent_model_old_API()

     prior = pm.draw(m["gp"], draws=1000)
     prior_ref = pm.draw(ref_m["gp"], draws=1000)
@@ -67,10 +66,10 @@ def test_latent_model_prior():


 def test_latent_model_logp():
-    m = latent_model()
+    m = build_latent_model()
     ip = m.initial_point()

-    ref_m, _ = latent_model_old_API()
+    ref_m, _ = build_latent_model_old_API()

     np.testing.assert_allclose(
         m.compile_logp()(ip),
@@ -89,7 +88,7 @@ def test_latent_model_conditional(inline):

     new_x = np.array([3, 4])[:, None]

-    m = latent_model()
+    m = build_latent_model()
     with m:
         pm.Deterministic("gp_exp", m["gp"].exp())

@@ -100,7 +99,7 @@ def test_latent_model_conditional(inline):
             progressbar=False,
         ).posterior_predictive

-    ref_m, ref_gp_class = latent_model_old_API()
+    ref_m, ref_gp_class = build_latent_model_old_API()
     with ref_m:
         gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x)
         pred_ref = pm.sample_posterior_predictive(