
Commit e9c1a9c

Add inline to conditional transform
1 parent f58327a commit e9c1a9c

File tree

2 files changed: +146 -73 lines changed


pymc_experimental/gp/pytensor_gp.py

Lines changed: 107 additions & 23 deletions
@@ -1,12 +1,12 @@
+from collections.abc import Sequence
+
 import pymc as pm
 import pytensor.tensor as pt
 
-from numpy.core.numeric import normalize_axis_tuple
 from pymc.distributions.distribution import Continuous
+from pymc.model.fgraph import fgraph_from_model, model_free_rv, model_from_fgraph
+from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.tensor.einsum import _delta
-
-# from pymc.logprob.abstract import MeasurableOp
 
 
 class GPCovariance(OpFromGraph):
@@ -23,7 +23,7 @@ def square_dist_Xs(X, Xs, ls):
         X2 = pt.sum(pt.square(X), axis=-1)
         Xs2 = pt.sum(pt.square(Xs), axis=-1)
 
-        sqd = -2.0 * X @ X.mT + (X2[..., :, None] + Xs2[..., None, :])
+        sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
         # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + (
         #     pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1))
         # )
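
Aside: the one-line fix above corrects the cross term of the pairwise squared distances (X @ X.mT becomes X @ Xs.mT), so the vectorized form matches the identity ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2. A minimal NumPy check of that identity (array names and shapes here are illustrative, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))   # hypothetical training inputs
Xs = rng.normal(size=(3, 2))  # hypothetical new inputs

X2 = np.sum(np.square(X), axis=-1)
Xs2 = np.sum(np.square(Xs), axis=-1)
sqd = -2.0 * X @ Xs.T + (X2[:, None] + Xs2[None, :])

# Brute-force pairwise squared distances for comparison
ref = np.square(X[:, None, :] - Xs[None, :, :]).sum(-1)
np.testing.assert_allclose(sqd, ref, atol=1e-12)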
@@ -68,25 +68,26 @@ def ExpQuad(X, X_new=None, *, ls):
     return ExpQuadCov.build_covariance(X, X_new, ls=ls)
 
 
-class WhiteNoiseCov(GPCovariance):
-    @classmethod
-    def white_noise_full(cls, X, sigma):
-        X_shape = tuple(X.shape)
-        shape = X_shape[:-1] + (X_shape[-2],)
-
-        return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
-
-    @classmethod
-    def build_covariance(cls, X, sigma):
-        X = pt.as_tensor(X)
-        sigma = pt.as_tensor(sigma)
-
-        ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
-        return ofg(X, sigma)
-
+# class WhiteNoiseCov(GPCovariance):
+#     @classmethod
+#     def white_noise_full(cls, X, sigma):
+#         X_shape = tuple(X.shape)
+#         shape = X_shape[:-1] + (X_shape[-2],)
+#
+#         return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
+#
+#     @classmethod
+#     def build_covariance(cls, X, sigma):
+#         X = pt.as_tensor(X)
+#         sigma = pt.as_tensor(sigma)
+#
+#         ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
+#         return ofg(X, sigma)
 
-def WhiteNoise(X, sigma):
-    return WhiteNoiseCov.build_covariance(X, sigma)
+#
+# def WhiteNoise(X, sigma):
+#     return WhiteNoiseCov.build_covariance(X, sigma)
+#
 
 
 class GP_RV(pm.MvNormal.rv_type):
@@ -108,6 +109,89 @@ def dist(cls, cov, **kwargs):
         return super().dist([mu, cov], **kwargs)
 
 
+def conditional_gp(
+    model,
+    gp: Variable | str,
+    Xnew,
+    *,
+    jitter=1e-6,
+    dims: Sequence[str] = (),
+    inline: bool = False,
+):
+    """
+    Condition a GP on new data.
+
+    Parameters
+    ----------
+    model : Model
+    gp : Variable | str
+        The GP to condition on.
+    Xnew : tensor-like
+        New data to condition the GP on.
+    jitter : float, default=1e-6
+        Jitter to add to the new GP covariance matrix.
+    dims : Sequence[str], default=()
+        Dimensions of the new GP.
+    inline : bool, default=False
+        Whether to inline the new GP in place of the old one. This is not always a safe operation.
+        If True, any variables that depend on the GP will be updated to depend on the new GP.
+
+    Returns
+    -------
+    conditional_model : Model
+        A new model with a GP free RV named f"{gp.name}_star" conditioned on the new data.
+
+    """
+
+    def _build_conditional(Xnew, f, cov, jitter):
+        if not isinstance(cov.owner.op, GPCovariance):
+            raise NotImplementedError(f"Cannot build conditional of {cov.owner.op} operation")
+        X, ls = cov.owner.inputs
+
+        Kxx = cov
+        Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
+        Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
+
+        L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
+        # TODO: Use cho_solve
+        A = pt.linalg.solve_triangular(L, Kxs, lower=True)
+        v = pt.linalg.solve_triangular(L, f, lower=True)
+
+        mu = (A.mT @ v).T  # Vector?
+        cov = Kss - (A.mT @ A)
+
+        return mu, cov
+
+    if isinstance(gp, Variable):
+        assert model[gp.name] is gp
+    else:
+        gp = model[gp]
+
+    fgraph, memo = fgraph_from_model(model)
+    gp_model_var = memo[gp]
+    gp_rv = gp_model_var.owner.inputs[0]
+
+    if isinstance(gp_rv.owner.op, pm.MvNormal.rv_type):
+        _, cov = gp_rv.owner.op.dist_params(gp_rv.owner)
+    else:
+        raise NotImplementedError("Can only condition on pure GPs")
+
+    # TODO: We should write the naive conditional covariance, and then have rewrites that lift it through kernels
+    mu_star, cov_star = _build_conditional(Xnew, gp_model_var, cov, jitter)
+    gp_rv_star = pm.MvNormal.dist(mu_star, cov_star, name=f"{gp.name}_star")
+
+    value = gp_rv_star.clone()
+    transform = None
+    gp_model_var_star = model_free_rv(gp_rv_star, value, transform, *dims)
+
+    if inline:
+        fgraph.replace(gp_model_var, gp_model_var_star, import_missing=True)
+    else:
+        fgraph.add_output(gp_model_var_star, import_missing=True)
+
+    return model_from_fgraph(fgraph, mutate_fgraph=True)
+
+
 # @register_canonicalize
 # @node_rewriter(tracks=[pm.MvNormal.rv_type])
 # def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node):
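
For reference, _build_conditional above implements the standard GP conditioning identities mu_star = Kxs^T (Kxx + jitter*I)^{-1} f and cov_star = Kss - Kxs^T (Kxx + jitter*I)^{-1} Kxs, computed through a Cholesky factor. A standalone NumPy sketch of the same algebra, with a hypothetical ExpQuad-style kernel standing in for the symbolic covariance, verified against a direct solve:

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 1))   # hypothetical training inputs
Xs = rng.normal(size=(3, 1))  # hypothetical new inputs
f = rng.normal(size=5)        # hypothetical GP draw at X
jitter = 1e-6

def expquad(A, B, ls=1.0):
    # Stand-in for the symbolic ExpQuad covariance: exp(-||a - b||^2 / (2 ls^2))
    sqd = np.square(A[:, None, :] - B[None, :, :]).sum(-1)
    return np.exp(-0.5 * sqd / ls**2)

Kxx, Kxs, Kss = expquad(X, X), expquad(X, Xs), expquad(Xs, Xs)

L = np.linalg.cholesky(Kxx + np.eye(5) * jitter)
A = np.linalg.solve(L, Kxs)   # solve_triangular in the diff; plain solve also works here
v = np.linalg.solve(L, f)
mu = A.T @ v                  # = Kxs^T (Kxx + jitter I)^{-1} f
cov = Kss - A.T @ A           # = Kss - Kxs^T (Kxx + jitter I)^{-1} Kxs

Kinv = np.linalg.inv(Kxx + np.eye(5) * jitter)
np.testing.assert_allclose(mu, Kxs.T @ Kinv @ f, atol=1e-8)
np.testing.assert_allclose(cov, Kss - Kxs.T @ Kinv @ Kxs, atol=1e-8)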

tests/test_gp.py

Lines changed: 39 additions & 50 deletions
@@ -1,8 +1,10 @@
+import arviz as az
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
+import pytest
 
-from pymc_experimental.gp.pytensor_gp import GP, ExpQuad
+from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp
 
 
 def test_exp_quad():
@@ -77,72 +79,59 @@ def test_latent_model_logp():
     )
 
 
-import arviz as az
-
-
-def gp_conditional(model, gp, Xnew, jitter=1e-6):
-    def _build_conditional(self, Xnew, f, cov, jitter):
-        X, ls = cov.owner.inputs
-
-        Kxx = cov
-        Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
-        Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
-
-        L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
-        # TODO: Use cho_solve
-        A = pt.linalg.solve_triangular(L, Kxs, lower=True)
-        v = pt.linalg.solve_triangular(L, f, lower=True)
-
-        mu = (A.mT @ v).T  # Vector?
-        cov = Kss - (A.mT @ A)
-
-        return mu, cov
-
-    with model.copy() as new_m:
-        gp = new_m[gp.name]
-        _, cov = gp.owner.op.dist_params(gp.owner)
-        mu_star, cov_star = _build_conditional(None, Xnew, gp, cov, jitter)
-        gp_star = pm.MvNormal("gp_star", mu_star, cov_star)
-    return new_m
-
-
-def test_latent_model_predict_new_x():
+@pytest.mark.parametrize("inline", (False, True))
+def test_latent_model_conditional(inline):
     rng = np.random.default_rng(0)
+    posterior = az.from_dict(
+        posterior={"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))},
+        constant_data={"X": np.arange(3)[:, None]},
+    )
+
     new_x = np.array([3, 4])[:, None]
 
     m = latent_model()
-    ref_m, ref_gp_class = latent_model_old_API()
+    with m:
+        pm.Deterministic("gp_exp", m["gp"].exp())
 
-    posterior_idata = az.from_dict({"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 2))})
-
-    # with gp_extend_to_new_x(m):
-    with gp_conditional(m, m["gp"], new_x):
-        pred = (
-            pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"])
-            .posterior_predictiev["gp"]
-            .values
-        )
+    with conditional_gp(m, m["gp"], new_x, inline=inline) as cgp:
+        pred = pm.sample_posterior_predictive(
+            posterior,
+            var_names=["gp_star", "gp_exp"],
+            progressbar=False,
+        ).posterior_predictive
 
+    ref_m, ref_gp_class = latent_model_old_API()
     with ref_m:
         gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x)
-        pred_ref = (
-            pm.sample_posterior_predictive(posterior_idata, var_names=["gp_star"])
-            .posterior_predictive["gp"]
-            .values
-        )
+        pred_ref = pm.sample_posterior_predictive(
+            posterior,
+            var_names=["gp_star"],
+            progressbar=False,
+        ).posterior_predictive
 
     np.testing.assert_allclose(
-        pred.mean(),
-        pred_ref.mean(),
+        pred["gp_star"].mean(),
+        pred_ref["gp_star"].mean(),
         atol=0.1,
     )
 
     np.testing.assert_allclose(
-        pred.std(),
-        pred_ref.std(),
+        pred["gp_star"].std(),
+        pred_ref["gp_star"].std(),
         rtol=0.1,
     )
 
+    if inline:
+        np.testing.assert_allclose(
+            pred["gp_exp"],
+            np.exp(pred["gp_star"]),
+        )
+    else:
+        np.testing.assert_allclose(
+            pred["gp_exp"],
+            np.exp(posterior.posterior["gp"]),
+        )
+
 
 #
 # def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model, ):
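
End to end, the new API exercised by the parametrized test reads roughly like the sketch below. latent_model() is defined elsewhere in this test module, so the GP/ExpQuad construction here is an assumption modeled on it; only conditional_gp and the "gp_star" naming come from this commit:

import arviz as az
import numpy as np
import pymc as pm

from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp

X = np.arange(3)[:, None]
with pm.Model() as m:
    gp = GP("gp", cov=ExpQuad(X, ls=1.0))  # assumed constructor, mirroring latent_model()

# Fabricated posterior draws for "gp", as in the test
rng = np.random.default_rng(0)
posterior = az.from_dict(posterior={"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))})

new_x = np.array([3, 4])[:, None]
with conditional_gp(m, m["gp"], new_x) as cm:
    # A free RV named "gp_star" was added; inline=True would additionally
    # rewire any dependents of "gp" onto the conditioned variable.
    pred = pm.sample_posterior_predictive(posterior, var_names=["gp_star"])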
