7
7
import pymc3 as pm
8
8
from pymc3 .gp .cov import Covariance
9
9
from pymc3 .gp .mean import Constant
10
- from pymc3 .gp .util import conditioned_vars
10
+ from pymc3 .gp .util import (conditioned_vars ,
11
+ infer_shape , stabilize , cholesky , solve , solve_lower , solve_upper )
11
12
from pymc3 .distributions import draw_values
12
13
13
14
__all__ = ['Latent' , 'Marginal' , 'TP' , 'MarginalSparse' ]
14
15
15
16
16
- cholesky = pm .distributions .dist_math .Cholesky (nofail = True , lower = True )
17
- solve_lower = tt .slinalg .Solve (A_structure = 'lower_triangular' )
18
- solve_upper = tt .slinalg .Solve (A_structure = 'upper_triangular' )
19
-
20
- def stabilize (K ):
21
- """ adds small diagonal to a covariance matrix """
22
- return K + 1e-6 * tt .identity_like (K )
23
-
24
-
25
17
class Base (object ):
26
18
"""
27
19
Base class
@@ -58,14 +50,15 @@ def conditional(self, name, n_points, Xnew, *args, **kwargs):
58
50
def predict (self , Xnew , point = None , given = None , diag = False ):
59
51
raise NotImplementedError
60
52
53
+
61
54
@conditioned_vars (["X" , "f" ])
62
55
class Latent (Base ):
63
56
""" Where the GP f isnt integrated out, and is sampled explicitly
64
57
"""
65
58
def __init__ (self , mean_func = None , cov_func = None ):
66
59
super (Latent , self ).__init__ (mean_func , cov_func )
67
60
68
- def _build_prior (self , name , n_points , X , reparameterize = True ):
61
+ def _build_prior (self , name , X , n_points , reparameterize = True ):
69
62
mu = self .mean_func (X )
70
63
chol = cholesky (stabilize (self .cov_func (X )))
71
64
if reparameterize :
@@ -75,17 +68,25 @@ def _build_prior(self, name, n_points, X, reparameterize=True):
75
68
f = pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
76
69
return f
77
70
78
- def prior (self , name , n_points , X , reparameterize = True ):
79
- f = self ._build_prior (name , n_points , X , reparameterize )
71
+ def prior (self , name , X , n_points = None , reparameterize = True ):
72
+ n_points = infer_shape (X , n_points )
73
+ f = self ._build_prior (name , X , n_points , reparameterize )
80
74
self .X = X
81
75
self .f = f
82
76
return f
83
77
84
- def _get_cond_vals (self , other = None ):
85
- if other is None :
86
- return self .X , self .f , self .cov_func , self .mean_func ,
78
+ def _get_given_vals (self , ** given ):
79
+ if 'gp' in given :
80
+ cov_total = given ['gp' ].cov_func
81
+ mean_total = given ['gp' ].mean_func
82
+ else :
83
+ cov_total = self .cov_func
84
+ mean_total = self .mean_func
85
+ if all (val in given for val in ['X' , 'f' ]):
86
+ X , f = given ['X' ], given ['f' ]
87
87
else :
88
- return other .X , other .f , other .cov_func , other .mean_func
88
+ X , f = self .X , self .f
89
+ return X , f , cov_total , mean_total
89
90
90
91
def _build_conditional (self , Xnew , X , f , cov_total , mean_total ):
91
92
Kxx = cov_total (X )
@@ -98,10 +99,11 @@ def _build_conditional(self, Xnew, X, f, cov_total, mean_total):
98
99
cov = Kss - tt .dot (tt .transpose (A ), A )
99
100
return mu , cov
100
101
101
- def conditional (self , name , n_points , Xnew , given = None ):
102
- X , f , cov_total , mean_total = self ._get_cond_vals ( given )
103
- mu , cov = self ._build_conditional (Xnew , X , f , cov_total , mean_total )
102
+ def conditional (self , name , Xnew , n_points = None , given = None ):
103
+ givens = self ._get_given_vals ( ** given )
104
+ mu , cov = self ._build_conditional (Xnew , * givens )
104
105
chol = cholesky (stabilize (cov ))
106
+ n_points = infer_shape (Xnew , n_points )
105
107
return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
106
108
107
109
@@ -150,11 +152,12 @@ def _build_conditional(self, Xnew, X, f):
150
152
covT = (self .nu + beta - 2 )/ (nu2 - 2 ) * cov
151
153
return nu2 , mu , covT
152
154
153
- def conditional (self , name , n_points , Xnew ):
155
+ def conditional (self , name , Xnew , n_points = None ):
154
156
X = self .X
155
157
f = self .f
156
158
nu2 , mu , covT = self ._build_conditional (Xnew , X , f )
157
159
chol = cholesky (stabilize (covT ))
160
+ n_points = infer_shape (Xnew , n_points )
158
161
return pm .MvStudentT (name , nu = nu2 , mu = mu , chol = chol , shape = n_points )
159
162
160
163
@@ -182,10 +185,24 @@ def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True
182
185
if is_observed :
183
186
return pm .MvNormal (name , mu = mu , chol = chol , observed = y )
184
187
else :
185
- if n_points is None :
186
- raise ValueError ("When `y` is not observed, `n_points` arg is required" )
188
+ n_points = infer_shape (X , n_points )
187
189
return pm .MvNormal (name , mu = mu , chol = chol , size = n_points )
188
190
191
+ def _get_given_vals (self , ** given ):
192
+ if 'gp' in given :
193
+ cov_total = given ['gp' ].cov_func
194
+ mean_total = given ['gp' ].mean_func
195
+ else :
196
+ cov_total = self .cov_func
197
+ mean_total = self .mean_func
198
+ if all (val in given for val in ['X' , 'y' , 'noise' ]):
199
+ X , y , noise = given ['X' ], given ['y' ], given ['noise' ]
200
+ if not isinstance (noise , Covariance ):
201
+ noise = pm .gp .cov .WhiteNoise (noise )
202
+ else :
203
+ X , y , noise = self .X , self .y , self .noise
204
+ return X , y , noise , cov_total , mean_total
205
+
189
206
def _build_conditional (self , Xnew , X , y , noise , cov_total , mean_total ,
190
207
pred_noise , diag = False ):
191
208
Kxx = cov_total (X )
@@ -209,38 +226,32 @@ def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
209
226
cov += noise (Xnew )
210
227
return mu , stabilize (cov )
211
228
212
- def _get_cond_vals (self , other = None ):
213
- # LOOK AT THIS MORE, where to X, y, noise need to come from? depending on situation
214
- # provide this function with **kwargs and return those if given? X=X, y=y, etc
215
- # i think this would be good. could build the gp "from scratch". caching these
216
- # from prior or marglike is a convenience
217
- if other is None :
218
- return self .X , self .y , self .noise , self .cov_func , self .mean_func ,
219
- else :
220
- return other .X , other .y , other .noise , other .cov_func , other .mean_func
221
-
222
- def conditional (self , name , n_points , Xnew , given = None , pred_noise = False ):
223
- # try to get n_points from X, (via cast to int?), error if cant and n_points is none
224
- X , y , noise , cov_total , mean_total = self ._get_cond_vals (given )
225
- mu , cov = self ._build_conditional (Xnew , X , y , noise , cov_total , mean_total ,
226
- pred_noise , diag = False )
229
+ def conditional (self , name , Xnew , pred_noise = False , n_points = None , ** given ):
230
+ givens = self ._get_given_vals (** given )
231
+ mu , cov = self ._build_conditional (Xnew , * givens , pred_noise , diag = False )
227
232
chol = cholesky (cov )
233
+ n_points = infer_shape (Xnew , n_points )
228
234
return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
229
235
230
- def predict (self , Xnew , point = None , given = None , pred_noise = False , diag = False ):
231
- X , y , noise , cov_total , mean_total = self ._get_cond_vals (given )
232
- mu , cov = self ._build_conditional (Xnew , X , y , noise , cov_total , mean_total ,
233
- pred_noise , diag )
236
+ def predict (self , Xnew , point = None , diag = False , pred_noise = False , ** given ):
237
+ mu , cov = self .predictt (Xnew , diag , pred_noise , ** given )
234
238
mu , cov = draw_values ([mu , cov ], point = point )
235
239
return mu , cov
236
240
241
+ def predictt (self , Xnew , diag = False , pred_noise = False , ** given ):
242
+ givens = self ._get_given_vals (** given )
243
+ mu , cov = self ._build_conditional (Xnew , * givens , pred_noise , diag )
244
+ return mu , cov
245
+
237
246
238
247
@conditioned_vars (["X" , "Xu" , "y" , "sigma" ])
239
- class MarginalSparse (Base ):
248
+ class MarginalSparse (Marginal ):
240
249
_available_approx = ["FITC" , "VFE" , "DTC" ]
241
250
""" FITC and VFE sparse approximations
242
251
"""
243
252
def __init__ (self , mean_func = None , cov_func = None , approx = "FITC" ):
253
+ if approx not in self ._available_approx :
254
+ raise NotImplementedError (approx )
244
255
self .approx = approx
245
256
super (MarginalSparse , self ).__init__ (mean_func , cov_func )
246
257
@@ -260,9 +271,7 @@ def _build_marginal_likelihood_logp(self, X, Xu, y, sigma):
260
271
Luu = cholesky (stabilize (Kuu ))
261
272
A = solve_lower (Luu , Kuf )
262
273
Qffd = tt .sum (A * A , 0 )
263
- if self .approx not in self ._available_approx :
264
- raise NotImplementedError (self .approx )
265
- elif self .approx == "FITC" :
274
+ if self .approx == "FITC" :
266
275
Kffd = self .cov_func (X , diag = True )
267
276
Lamd = tt .clip (Kffd - Qffd , 0.0 , np .inf ) + sigma2
268
277
trace = 0.0
@@ -293,9 +302,8 @@ def marginal_likelihood(self, name, X, Xu, y, sigma, n_points=None, is_observed=
293
302
if is_observed : # same thing ith n_points here?? check
294
303
return pm .DensityDist (name , logp , observed = y )
295
304
else :
296
- if n_points is None :
297
- raise ValueError ("When `y` is not observed, `n_points` arg is required" )
298
- return pm .DensityDist (name , logp , size = n_points ) # not, dont need size arg
305
+ n_points = infer_shape (X , n_points )
306
+ return pm .DensityDist (name , logp , size = n_points )
299
307
300
308
def _build_conditional (self , Xnew , Xu , X , y , sigma , cov_total , mean_total ,
301
309
pred_noise , diag = False ):
@@ -305,9 +313,7 @@ def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
305
313
Luu = cholesky (stabilize (Kuu ))
306
314
A = solve_lower (Luu , Kuf )
307
315
Qffd = tt .sum (A * A , 0 )
308
- if self .approx not in self ._available_approx :
309
- raise NotImplementedError (self .approx )
310
- elif self .approx == "FITC" :
316
+ if self .approx == "FITC" :
311
317
Kffd = cov_total (X , diag = True )
312
318
Lamd = tt .clip (Kffd - Qffd , 0.0 , np .inf ) + sigma2
313
319
else : # VFE or DTC
@@ -334,25 +340,17 @@ def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
334
340
cov += sigma2 * tt .identity_like (cov )
335
341
return mu , stabilize (cov )
336
342
337
- def _get_cond_vals (self , other = None ):
338
- if other is None :
339
- return self .X , self .Xu , self .y , self .sigma , self .cov_func , self .mean_func ,
343
+ def _get_given_vals (self , ** given ):
344
+ if 'gp' in given :
345
+ cov_total = given ['gp' ].cov_func
346
+ mean_total = given ['gp' ].mean_func
340
347
else :
341
- return other .X , self .Xu , other .y , other .sigma , other .cov_func , other .mean_func
342
-
343
- def conditional (self , name , n_points , Xnew , given = None , pred_noise = False ):
344
- # try to get n_points from X, (via cast to int?), error if cant and n_points is none
345
- X , Xu , y , sigma , cov_total , mean_total = self ._get_cond_vals (given )
346
- mu , cov = self ._build_conditional (Xnew , Xu , X , y , sigma , cov_total , mean_total ,
347
- pred_noise , diag = False )
348
- chol = cholesky (cov )
349
- return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
350
-
351
- def predict (self , Xnew , point = None , given = None , pred_noise = False , diag = False ):
352
- X , Xu , y , sigma , cov_total , mean_total = self ._get_cond_vals (given )
353
- mu , cov = self ._build_conditional (Xnew , Xu , X , y , sigma , cov_total , mean_total ,
354
- pred_noise , diag )
355
- mu , cov = draw_values ([mu , cov ], point = point )
356
- return mu , cov
348
+ cov_total = self .cov_func
349
+ mean_total = self .mean_func
350
+ if all (val in given for val in ['X' , 'Xu' , 'y' , 'sigma' ]):
351
+ X , Xu , y , sigma = given ['X' ], given ['Xu' ], given ['y' ], given ['sigma' ]
352
+ else :
353
+ X , Xu , y , sigma = self .X , self .y , self .sigma
354
+ return X , y , sigma , cov_total , mean_total
357
355
358
356
0 commit comments