@@ -13,7 +13,7 @@ class GPCovariance(OpFromGraph):
     """OFG representing a GP covariance"""
 
     @staticmethod
-    def square_dist_Xs(X, Xs, ls):
+    def square_dist(X, Xs, ls):
         assert X.ndim == 2, "Complain to Bill about it"
         assert Xs.ndim == 2, "Complain to Bill about it"
 
@@ -24,20 +24,8 @@ def square_dist_Xs(X, Xs, ls):
         Xs2 = pt.sum(pt.square(Xs), axis=-1)
 
         sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
-        # sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + (
-        #     pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1))
-        # )
-
         return pt.clip(sqd, 0, pt.inf)
 
-    @staticmethod
-    def square_dist(X, ls):
-        X = X / ls
-        X2 = pt.sum(pt.square(X), axis=-1)
-        sqd = -2.0 * X @ X.mT + (X2[..., :, None] + X2[..., None, :])
-
-        return sqd
-
 
 class ExpQuadCov(GPCovariance):
     """
@@ -46,7 +34,7 @@ class ExpQuadCov(GPCovariance):
 
     @classmethod
     def exp_quad_full(cls, X, Xs, ls):
-        return pt.exp(-0.5 * cls.square_dist_Xs(X, Xs, ls))
+        return pt.exp(-0.5 * cls.square_dist(X, Xs, ls))
 
     @classmethod
     def build_covariance(cls, X, Xs=None, *, ls):
@@ -64,32 +52,10 @@ def build_covariance(cls, X, Xs=None, *, ls):
         return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls)
 
 
-def ExpQuad(X, X_new=None, *, ls):
+def ExpQuad(X, X_new=None, *, ls=1.0):
     return ExpQuadCov.build_covariance(X, X_new, ls=ls)
 
 
-# class WhiteNoiseCov(GPCovariance):
-#     @classmethod
-#     def white_noise_full(cls, X, sigma):
-#         X_shape = tuple(X.shape)
-#         shape = X_shape[:-1] + (X_shape[-2],)
-#
-#         return _delta(shape, normalize_axis_tuple((-1, -2), X.ndim)) * sigma**2
-#
-#     @classmethod
-#     def build_covariance(cls, X, sigma):
-#         X = pt.as_tensor(X)
-#         sigma = pt.as_tensor(sigma)
-#
-#         ofg = cls(inputs=[X, sigma], outputs=[cls.white_noise_full(X, sigma)])
-#         return ofg(X, sigma)
-
-#
-# def WhiteNoise(X, sigma):
-#     return WhiteNoiseCov.build_covariance(X, sigma)
-#
-
-
 class GP_RV(pm.MvNormal.rv_type):
     name = "gaussian_process"
     signature = "(n),(n,n)->(n)"
@@ -103,7 +69,6 @@ class GP(Continuous):
 
     @classmethod
     def dist(cls, cov, **kwargs):
-        # return Assert(msg="Don't know what a GP_RV is")(False)
         cov = pt.as_tensor(cov)
         mu = pt.zeros(cov.shape[-1])
         return super().dist([mu, cov], **kwargs)
@@ -190,89 +155,3 @@ def _build_conditional(Xnew, f, cov, jitter):
     fgraph.add_output(gp_model_var_star, import_missing=True)
 
     return model_from_fgraph(fgraph, mutate_fgraph=True)
-
-
-# @register_canonicalize
-# @node_rewriter(tracks=[pm.MvNormal.rv_type])
-# def GP_normal_mvnormal_conjugacy(fgraph: FunctionGraph, node):
-#     # TODO: Should this alert users that it can't be applied when the GP is in a deterministic?
-#     gp_rng, gp_size, mu, cov = node.inputs
-#     next_gp_rng, gp_rv = node.outputs
-#
-#     if not isinstance(cov.owner.op, GPCovariance):
-#         return
-#
-#     for client, input_index in fgraph.clients[gp_rv]:
-#         # input_index is 2 because it goes (rng, size, mu, sigma), and we want the mu
-#         # to be the GP we're looking
-#         if isinstance(client.op, pm.Normal.rv_type) and (input_index == 2):
-#             next_normal_rng, normal_rv = client.outputs
-#             normal_rng, normal_size, mu, sigma = client.inputs
-#
-#             if normal_rv.ndim != gp_rv.ndim:
-#                 return
-#
-#             X = cov.owner.inputs[0]
-#
-#             white_noise = WhiteNoiseCov.build_covariance(X, sigma)
-#             white_noise.name = 'WhiteNoiseCov'
-#             cov = cov + white_noise
-#
-#             if not rv_size_is_none(normal_size):
-#                 normal_size = tuple(normal_size)
-#                 new_gp_size = normal_size[:-1]
-#                 core_shape = normal_size[-1]
-#
-#                 cov_shape = (*(None,) * (cov.ndim - 2), core_shape, core_shape)
-#                 cov = pt.specify_shape(cov, cov_shape)
-#
-#             else:
-#                 new_gp_size = None
-#
-#             next_new_gp_rng, new_gp_mvn = pm.MvNormal.dist(cov=cov, rng=gp_rng, size=new_gp_size).owner.outputs
-#             new_gp_mvn.name = 'NewGPMvn'
-#
-#             # Check that the new shape is at least as specific as the shape we are replacing
-#             for new_shape, old_shape in zip(new_gp_mvn.type.shape, normal_rv.type.shape, strict=True):
-#                 if new_shape is None:
-#                     assert old_shape is None
-#
-#             return {
-#                 next_normal_rng: next_new_gp_rng,
-#                 normal_rv: new_gp_mvn,
-#                 next_gp_rng: next_new_gp_rng
-#             }
-#
-#         else:
-#             return None
-#
-# #TODO: Why do I need to register this twice?
-# specialization_ir_rewrites_db.register(
-#     GP_normal_mvnormal_conjugacy.__name__,
-#     GP_normal_mvnormal_conjugacy,
-#     "basic",
-# )
-
-# @node_rewriter(tracks=[pm.MvNormal.rv_type])
-# def GP_normal_marginal_logp(fgraph: FunctionGraph, node):
-#     """
-#     Replace Normal(GP(cov), sigma) -> MvNormal(0, cov + diag(sigma)).
-#     """
-#     rng, size, mu, cov = node.inputs
-#     if cov.owner and cov.owner.op == matrix_inverse:
-#         tau = cov.owner.inputs[0]
-#         return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
-#     return None
-#
-
-# cov_op = GPCovariance()
-# gp_op = GP("vanilla")
-# # SymbolicRandomVariable.register(type(gp_op))
-# prior_from_gp = PriorFromGP()
-#
-# MeasurableVariable.register(type(prior_from_gp))
-#
-#
-# @_get_measurable_outputs.register(type(prior_from_gp))
-# def gp_measurable_outputs(op, node):
-#     return node.outputs
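For context on what the consolidated helper computes, here is a minimal standalone sketch (not part of the diff; it assumes only numpy and pytensor and uses illustrative module-level function names rather than the OpFromGraph classmethods) of the lengthscale-scaled squared distances and the exponentiated-quadratic kernel that `square_dist` and `exp_quad_full` implement:

import numpy as np
import pytensor.tensor as pt

def square_dist(X, Xs, ls):
    # Scale inputs by the lengthscale, then form pairwise squared distances
    # via ||x - x'||^2 = ||x||^2 + ||x'||^2 - 2 x.x'.
    X, Xs = X / ls, Xs / ls
    X2 = pt.sum(pt.square(X), axis=-1)
    Xs2 = pt.sum(pt.square(Xs), axis=-1)
    sqd = -2.0 * X @ Xs.T + (X2[:, None] + Xs2[None, :])
    # Clip tiny negative values caused by floating-point error.
    return pt.clip(sqd, 0.0, np.inf)

def exp_quad(X, Xs, ls):
    # Exponentiated-quadratic (RBF) kernel from the squared distances.
    return pt.exp(-0.5 * square_dist(X, Xs, ls))

X = pt.matrix("X")          # (n, d) input locations
K = exp_quad(X, X, ls=1.0)  # (n, n) covariance matrix, ls defaults to 1.0 as in ExpQuad
K.eval({X: np.random.default_rng(0).normal(size=(5, 2))})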