
Commit 3feaf2e

Rough implementation of new API
1 parent 1389a8b commit 3feaf2e

File tree

2 files changed: +201 -65 lines changed


pymc_experimental/gp/pytensor_gp.py

Lines changed: 141 additions & 65 deletions
@@ -1,92 +1,168 @@
-import numpy as np
 import pymc as pm
-import pytensor
 import pytensor.tensor as pt

-from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs
-from pytensor.graph.op import Apply, Op
+from numpy.core.numeric import normalize_axis_tuple
+from pymc.distributions.distribution import Continuous
+from pytensor.compile.builders import OpFromGraph
+from pytensor.tensor.einsum import _delta

+# from pymc.logprob.abstract import MeasurableOp

-class Cov(Op):
-    __props__ = ("fn",)

-    def __init__(self, fn):
-        self.fn = fn
+class GPCovariance(OpFromGraph):
+    """OFG representing a GP covariance"""

-    def make_node(self, ls):
-        ls = pt.as_tensor(ls)
-        out = pt.matrix(shape=(None, None))
-
-        return Apply(self, [ls], [out])
-
-    def __call__(self, ls=1.0):
-        return super().__call__(ls)
-
-    def perform(self, node, inputs, output_storage):
-        raise NotImplementedError("You should convert Cov into a TensorVariable expression!")
-
-    def do_constant_folding(self, fgraph, node):
-        return False
+    @staticmethod
+    def square_dist(X, ls):
+        X = X / ls
+        X2 = pt.sum(pt.square(X), axis=-1)
+        sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :])

+        return sqd

-class GP(Op):
-    __props__ = ("approx",)

-    def __init__(self, approx):
-        self.approx = approx
+class ExpQuadCov(GPCovariance):
+    """
+    ExpQuad covariance function
+    """

-    def make_node(self, mean, cov):
-        mean = pt.as_tensor(mean)
-        cov = pt.as_tensor(cov)
-
-        if not (cov.owner and isinstance(cov.owner.op, Cov)):
-            raise ValueError("Second argument should be a Cov output.")
-
-        out = pt.vector(shape=(None,))
+    @classmethod
+    def exp_quad_full(cls, X, ls):
+        return pt.exp(-0.5 * cls.square_dist(X, ls))

-        return Apply(self, [mean, cov], [out])
+    @classmethod
+    def build_covariance(cls, X, ls):
+        X = pt.as_tensor(X)
+        ls = pt.as_tensor(ls)

-    def perform(self, node, inputs, output_storage):
-        raise NotImplementedError("You cannot evaluate a GP, not enough RAM in the Universe.")
+        ofg = cls(inputs=[X, ls], outputs=[cls.exp_quad_full(X, ls)])
+        return ofg(X, ls)

-    def do_constant_folding(self, fgraph, node):
-        return False

+def ExpQuad(X, ls):
+    return ExpQuadCov.build_covariance(X, ls)

-class PriorFromGP(Op):
-    """This Op will be replaced by the right MvNormal."""

-    def make_node(self, gp, x, rng):
-        gp = pt.as_tensor(gp)
-        if not (gp.owner and isinstance(gp.owner.op, GP)):
-            raise ValueError("First argument should be a GP output.")
+class WhiteNoiseCov(GPCovariance):
+    @classmethod
+    def white_noise_full(cls, X, sigma):
+        X_shape = tuple(X.shape)
+        shape = X_shape[:-1] + (X_shape[-2],)

-        # TODO: Assert RNG has the right type
-        x = pt.as_tensor(x)
-        out = x.type()
+        return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2

-        return Apply(self, [gp, x, rng], [out])
+    @classmethod
+    def build_covariance(cls, X, sigma):
+        X = pt.as_tensor(X)
+        sigma = pt.as_tensor(sigma)

-    def __call__(self, gp, x, rng=None):
-        if rng is None:
-            rng = pytensor.shared(np.random.default_rng())
-        return super().__call__(gp, x, rng)
+        ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
+        return ofg(X, sigma)

-    def perform(self, node, inputs, output_storage):
-        raise NotImplementedError("You should convert PriorFromGP into a MvNormal!")

-    def do_constant_folding(self, fgraph, node):
-        return False
+def WhiteNoise(X, sigma):
+    return WhiteNoiseCov.build_covariance(X, sigma)


-cov_op = Cov(fn=pm.gp.cov.ExpQuad)
-gp_op = GP("vanilla")
-# SymbolicRandomVariable.register(type(gp_op))
-prior_from_gp = PriorFromGP()
+class GP_RV(pm.MvNormal.rv_type):
+    name = "gaussian_process"
+    signature = "(n),(n,n)->(n)"
+    dtype = "floatX"
+    _print_name = ("GP", "\\operatorname{GP}")

-MeasurableVariable.register(type(prior_from_gp))

+class GP(Continuous):
+    rv_type = GP_RV
+    rv_op = GP_RV()

-@_get_measurable_outputs.register(type(prior_from_gp))
-def gp_measurable_outputs(op, node):
-    return node.outputs
+    @classmethod
+    def dist(cls, cov, **kwargs):
+        cov = pt.as_tensor(cov)
+        mu = pt.zeros(cov.shape[-1])
+        return super().dist([mu, cov], **kwargs)
+
+
+# @register_canonicalize
+# @node_rewriter(tracks=[pm.MvNormal.rv_type])
+# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node):
+#     # TODO: Should this alert users that it can't be applied when the GP is in a deterministic?
+#     gp_rng, gp_size, mu, cov = node.inputs
+#     next_gp_rng, gp_rv = node.outputs
+#
+#     if not isinstance(cov.owner.op, GPCovariance):
+#         return
+#
+#     for client, input_index in fgraph.clients[gp_rv]:
+#         # input_index is 2 because it goes (rng, size, mu, sigma), and we want the mu
+#         # to be the GP we're looking
+#         if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2):
+#             next_normal_rng, normal_rv = client.outputs
+#             normal_rng, normal_size, mu, sigma = client.inputs
+#
+#             if normal_rv.ndim != gp_rv.ndim:
+#                 return
+#
+#             X = cov.owner.inputs[0]
+#
+#             white_noise = WhiteNoiseCov.build_covariance(X, sigma)
+#             white_noise.name = 'WhiteNoiseCov'
+#             cov = cov + white_noise
+#
+#             if not rv_size_is_none(normal_size):
+#                 normal_size = tuple(normal_size)
+#                 new_gp_size = normal_size[:-1]
+#                 core_shape = normal_size[-1]
+#
+#                 cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape)
+#                 cov = pt.specify_shape(cov, cov_shape)
+#
+#             else:
+#                 new_gp_size = None
+#
+#             next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs
+#             new_gp_mvn.name = 'NewGPMvn'
+#
+#             # Check that the new shape is at least as specific as the shape we are replacing
+#             for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True):
+#                 if new_shape is None:
+#                     assert old_shape is None
+#
+#             return {
+#                 next_normal_rng: next_new_gp_rng,
+#                 normal_rv: new_gp_mvn,
+#                 next_gp_rng: next_new_gp_rng
+#             }
+#
+#         else:
+#             return None
+#
+# #TODO: Why do I need to register this twice?
+# specialization_ir_rewrites_db.register(
+#     GP_normal_mvnormal_conjugacy.__name__,
+#     GP_normal_mvnormal_conjugacy,
+#     "basic",
+# )
+
+# @node_rewriter(tracks=[pm.MvNormal.rv_type])
+# def GP_normal_marginal_logp(fgraph: FunctionGraph, node):
+#     """
+#     Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)).
+#     """
+#     rng, size, mu, cov = node.inputs
+#     if cov.owner and cov.owner.op == matrix_inverse:
+#         tau = cov.owner.inputs[0]
+#         return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
+#     return None
+#
+
+# cov_op = GPCovariance()
+# gp_op = GP("vanilla")
+# # SymbolicRandomVariable.register(type(gp_op))
+# prior_from_gp = PriorFromGP()
+#
+# MeasurableVariable.register(type(prior_from_gp))
+#
+#
+# @_get_measurable_outputs.register(type(prior_from_gp))
+# def gp_measurable_outputs(op, node):
+#     return node.outputs

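For reference, a minimal NumPy sketch (not part of the commit) of what the new square_dist / exp_quad_full pair computes, using the expansion ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 x_i . x_j; the helper name exp_quad_np is hypothetical:

import numpy as np

def exp_quad_np(X, ls=1.0):
    # Squared Euclidean distances via the dot-product expansion used by square_dist
    X = np.asarray(X, dtype=float) / ls
    X2 = np.sum(X**2, axis=-1)
    sqd = -2.0 * X @ X.T + (X2[:, None] + X2[None, :])
    # ExpQuad kernel: exp(-0.5 * squared distance)
    return np.exp(-0.5 * sqd)

# Same inputs as test_exp_quad below: X = [[0], [1], [2]], ls = 1
X = np.arange(3)[:, None]
expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
np.testing.assert_allclose(exp_quad_np(X), np.exp(-0.5 * expected_distance))
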
tests/test_gp.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+import pytest
+
+from pymc_experimental.gp.pytensor_gp import GP, ExpQuad
+
+
+def test_exp_quad():
+    x = pt.arange(3)[:, None]
+    ls = pt.ones(())
+    cov = ExpQuad(x, ls).eval()
+    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+
+    np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))
+
+
+@pytest.fixture(scope="session")
+def marginal_model():
+    with pm.Model() as m:
+        X = pm.Data("X", np.arange(3)[:, None])
+        y = np.full(3, np.pi)
+        ls = 1.0
+        cov = ExpQuad(X, ls)
+        gp = GP("gp", cov=cov)
+
+        sigma = 1.0
+        obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)
+
+    return m
+
+
+def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model):
+    obs = marginal_model["obs"]
+
+    # TODO: Bring these checks back after we implement marginalization of the GP RV
+    #
+    # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
+    #            for var in ancestors([obs])
+    #            if var.owner is not None) == 1
+    #
+    f = pm.compile_pymc([], obs)
+    #
+    # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
+
+    draws = np.stack([f() for _ in range(10_000)])
+    empirical_cov = np.cov(draws.T)
+
+    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+
+    np.testing.assert_allclose(
+        empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
+    )
+
+
+def test_marginal_gp_logp(marginal_model):
+    expected_logps = {"obs": -8.8778}
+    point_logps = marginal_model.point_logps(round_vals=4)
+    for v1, v2 in zip(point_logps.values(), expected_logps.values()):
+        np.testing.assert_allclose(v1, v2, atol=1e-6)

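For reference, the expected value -8.8778 in test_marginal_gp_logp is consistent with the marginal density the commented-out GP_normal_mvnormal_conjugacy rewrite is meant to produce: with f ~ MvNormal(0, K) and y | f ~ Normal(f, sigma), marginalizing f gives y ~ MvNormal(0, K + sigma**2 * I). A minimal NumPy/SciPy check under that assumption (not part of the commit):

import numpy as np
from scipy import stats

# ExpQuad kernel on X = [[0], [1], [2]] with ls = 1, as in the fixture
expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
K = np.exp(-0.5 * expected_distance)

# Marginal covariance K + sigma**2 * I with sigma = 1, evaluated at y = [pi, pi, pi]
y = np.full(3, np.pi)
logp = stats.multivariate_normal(mean=np.zeros(3), cov=K + np.eye(3)).logpdf(y)
print(round(logp, 4))  # -8.8778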