+from collections.abc import Sequence
+
 import pymc as pm
 import pytensor.tensor as pt
 
-from numpy.core.numeric import normalize_axis_tuple
 from pymc.distributions.distribution import Continuous
+from pymc.model.fgraph import fgraph_from_model, model_free_rv, model_from_fgraph
+from pytensor import Variable
 from pytensor.compile.builders import OpFromGraph
-from pytensor.tensor.einsum import _delta
-
-# from pymc.logprob.abstract import MeasurableOp
 
 
 class GPCovariance(OpFromGraph):
@@ -23,7 +23,7 @@ def square_dist_Xs(X, Xs, ls):
         X2 = pt.sum(pt.square(X), axis=-1)
         Xs2 = pt.sum(pt.square(Xs), axis=-1)
 
-        sqd = -2.0 * X @ X.mT + (X2[..., :, None] + Xs2[..., None, :])
+        sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
         # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + (
         #     pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1))
         # )
@@ -68,25 +68,26 @@ def ExpQuad(X, X_new=None, *, ls):
     return ExpQuadCov.build_covariance(X, X_new, ls=ls)
 
 
-class WhiteNoiseCov(GPCovariance):
-    @classmethod
-    def white_noise_full(cls, X, sigma):
-        X_shape = tuple(X.shape)
-        shape = X_shape[:-1] + (X_shape[-2],)
-
-        return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
-
-    @classmethod
-    def build_covariance(cls, X, sigma):
-        X = pt.as_tensor(X)
-        sigma = pt.as_tensor(sigma)
-
-        ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
-        return ofg(X, sigma)
-
+# class WhiteNoiseCov(GPCovariance):
+#     @classmethod
+#     def white_noise_full(cls, X, sigma):
+#         X_shape = tuple(X.shape)
+#         shape = X_shape[:-1] + (X_shape[-2],)
+#
+#         return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
+#
+#     @classmethod
+#     def build_covariance(cls, X, sigma):
+#         X = pt.as_tensor(X)
+#         sigma = pt.as_tensor(sigma)
+#
+#         ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
+#         return ofg(X, sigma)
 
-def WhiteNoise(X, sigma):
-    return WhiteNoiseCov.build_covariance(X, sigma)
+#
+# def WhiteNoise(X, sigma):
+#     return WhiteNoiseCov.build_covariance(X, sigma)
+#
 
 
 class GP_RV(pm.MvNormal.rv_type):
@@ -108,6 +109,89 @@ def dist(cls, cov, **kwargs):
         return super().dist([mu, cov], **kwargs)
 
 
+def conditional_gp(
+    model,
+    gp: Variable | str,
+    Xnew,
+    *,
+    jitter=1e-6,
+    dims: Sequence[str] = (),
+    inline: bool = False,
+):
+    """
+    Condition a GP on new data.
+
+    Parameters
+    ----------
+    model : Model
+        The model containing the GP.
+    gp : Variable | str
+        The GP to condition on.
+    Xnew : tensor-like
+        New data to condition the GP on.
+    jitter : float, default 1e-6
+        Jitter added to the diagonal of the training covariance before factorization.
+    dims : Sequence[str], default ()
+        Dimensions of the new GP.
+    inline : bool, default False
+        Whether to inline the new GP in place of the old one. This is not always a safe operation.
+        If True, any variables that depend on the GP will be updated to depend on the new GP.
+
+    Returns
+    -------
+    conditional_model : Model
+        A new model with a GP free RV named f"{gp.name}_star" conditioned on the new data.
+    """
+
+    def _build_conditional(Xnew, f, cov, jitter):
+        if not isinstance(cov.owner.op, GPCovariance):
+            raise NotImplementedError(f"Cannot build conditional of {cov.owner.op} operation")
+        X, ls = cov.owner.inputs
+
+        Kxx = cov
+        Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
+        Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
+
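+        # Standard GP conditioning identities, computed via a Cholesky factor
+        # L of (Kxx + jitter * I): with A = solve(L, Kxs) and v = solve(L, f),
+        #   mu_star  = Kxs.T @ inv(Kxx) @ f         = A.T @ v
+        #   cov_star = Kss - Kxs.T @ inv(Kxx) @ Kxs = Kss - A.T @ A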
+        L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
+        # TODO: Use cho_solve
+        A = pt.linalg.solve_triangular(L, Kxs, lower=True)
+        v = pt.linalg.solve_triangular(L, f, lower=True)
+
+        mu = (A.mT @ v).T  # Vector?
+        cov = Kss - (A.mT @ A)
+
+        return mu, cov
+
+    if isinstance(gp, Variable):
+        assert model[gp.name] is gp
+    else:
+        gp = model[gp.name]
+
+    fgraph, memo = fgraph_from_model(model)
+    gp_model_var = memo[gp]
+    gp_rv = gp_model_var.owner.inputs[0]
+
+    if isinstance(gp_rv.owner.op, pm.MvNormal.rv_type):
+        _, cov = gp_rv.owner.op.dist_params(gp_rv.owner)
+    else:
+        raise NotImplementedError("Can only condition on pure GPs")
+
+    # TODO: We should write the naive conditional covariance, and then have rewrites that lift it through kernels
+    mu_star, cov_star = _build_conditional(Xnew, gp_model_var, cov, jitter)
+    gp_rv_star = pm.MvNormal.dist(mu_star, cov_star, name=f"{gp.name}_star")
+
+    value = gp_rv_star.clone()
+    transform = None
+    gp_model_var_star = model_free_rv(gp_rv_star, value, transform, *dims)
+
+    if inline:
+        fgraph.replace(gp_model_var, gp_model_var_star, import_missing=True)
+    else:
+        fgraph.add_output(gp_model_var_star, import_missing=True)
+
+    return model_from_fgraph(fgraph, mutate_fgraph=True)
+
+
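+# Usage sketch (illustrative, untested). `X`, `Xnew`, `y_obs`, and the
+# variable names are placeholders:
+#
+#   with pm.Model() as model:
+#       ls = pm.HalfNormal("ls")
+#       f = GP("f", cov=ExpQuad(X, ls=ls))
+#       y = pm.Normal("y", mu=f, sigma=0.1, observed=y_obs)
+#
+#   # New model with an extra free RV named "f_star" at the new inputs
+#   star_model = conditional_gp(model, "f", Xnew)
+#   # pm.sample_posterior_predictive(idata, model=star_model, var_names=["f_star"])
+
+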
 # @register_canonicalize
 # @node_rewriter(tracks=[pm.MvNormal.rv_type])
 # def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node):