@@ -193,31 +193,28 @@ def _prior_rv(self, X, y, sigma=None, cov_func_noise=None):
193
193
194
194
def _predictive_rv (self , X , y , Xs , sigma = None , cov_func_noise = None ):
195
195
cov_func_noise = self ._to_noise_func (sigma , cov_func_noise )
196
- mu , chol = self ._build_predictive (X , Xs , y , cov_func_noise )
196
+ mu , chol = self ._build_predictive (X , y , Xs , cov_func_noise )
197
197
return pm .MvNormal (self .name , mu = mu , chol = chol , shape = self .size )
198
198
199
199
200
200
201
201
class GPMarginalSparse (GPBase ):
202
202
""" FITC and VFE sparse approximations
203
203
"""
204
- def __init__ (self , cov_func ):
204
+ def __init__ (self , cov_func , approx = None ):
205
+ if approx is None :
206
+ approx = "FITC"
207
+ self .approx = approx
205
208
super (GPMarginalSparse , self ).__init__ (cov_func )
206
209
207
- def __call__ (self , name , size , mean_func , include_noise = False , approx = None ):
208
- if hasattr (self , "approx" ) and self .approx != approx :
209
- raise ValueError ("dont use diff approx (should this be a warning?)" )
210
- else :
211
- if approx is None :
212
- pm ._log .info ("Using FITC approximation for {}" .format (name ))
213
- approx = "FITC"
214
- else :
215
- approx = approx .upper ()
216
- if approx not in ["VFE" , "FITC" ]:
217
- raise ValueError (("'FITC' or 'VFE' are the supported GP "
218
- "approximations, not {}" .format (approx )))
219
- self .approx = approx
210
+ # overriding __add__, since whether its vfe or fitc determines its 'type'
211
+ def __add__ (self , other ):
212
+ if not (isinstance (self , type (other )) and self .approx == other .approx ):
213
+ raise ValueError ("cant add different GP types" )
214
+ cov_func = self .cov_func + other .cov_func
215
+ return type (self )(cov_func )
220
216
217
+ def __call__ (self , name , size , mean_func , include_noise = False ):
221
218
self .include_noise = include_noise
222
219
return super (GPMarginalSparse , self ).__call__ (name , size , mean_func )
223
220
@@ -240,7 +237,7 @@ def kmeans_inducing_points(self, n_inducing, X):
240
237
Xu , distortion = kmeans (Xw , n_inducing )
241
238
return Xu * scaling
242
239
243
- def _build_prior (self , X , Xu , y , sigma ):
240
+ def _build_prior_logp (self , X , Xu , y , sigma ):
244
241
sigma2 = tt .square (sigma )
245
242
Kuu = self .cov_func (Xu )
246
243
Kuf = self .cov_func (Xu , X )
@@ -257,7 +254,7 @@ def _build_prior(self, X, Xu, y, sigma):
257
254
(tt .sum (self .cov_func (X , diag = True )) -
258
255
tt .sum (tt .sum (A * A , 0 ))))
259
256
else :
260
- raise NotImplementedError (approx )
257
+ raise NotImplementedError (self . approx )
261
258
A_l = A / Lamd
262
259
L_B = cholesky (tt .eye (Xu .shape [0 ]) + tt .dot (A_l , tt .transpose (A )))
263
260
r = y - self .mean_func (X )
0 commit comments