@@ -55,6 +55,8 @@ def marginal_likelihood(self, name, X, *args, **kwargs):
55
55
def conditional (self , name , n_points , Xnew , * args , ** kwargs ):
56
56
raise NotImplementedError
57
57
58
+ def predict (self , Xnew , point = None , given = None , diag = False ):
59
+ raise NotImplementedError
58
60
59
61
@conditioned_vars (["X" , "f" ])
60
62
class Latent (Base ):
@@ -85,33 +87,23 @@ def _get_cond_vals(self, other=None):
85
87
else :
86
88
return other .X , other .f , other .cov_func , other .mean_func
87
89
88
- def _build_conditional (self , Xnew , X , f , cov_total , mean_total , diag = False ):
90
+ def _build_conditional (self , Xnew , X , f , cov_total , mean_total ):
89
91
Kxx = cov_total (X )
90
92
Kxs = self .cov_func (X , Xnew )
91
93
L = cholesky (stabilize (Kxx ))
92
94
A = solve_lower (L , Kxs )
93
95
v = solve_lower (L , f - mean_total (X ))
94
96
mu = self .mean_func (Xnew ) + tt .dot (tt .transpose (A ), v )
95
- if diag :
96
- Kss = self .cov_func (Xnew , diag = True )
97
- cov = Kss - tt .sum (tt .square (A ), 0 )
98
- else :
99
- Kss = self .cov_func (Xnew )
100
- cov = Kss - tt .dot (tt .transpose (A ), A )
97
+ Kss = self .cov_func (Xnew )
98
+ cov = Kss - tt .dot (tt .transpose (A ), A )
101
99
return mu , cov
102
100
103
- def conditional (self , name , n_points , Xnew , gp = None ):
104
- X , f , cov_total , mean_total = self ._get_cond_vals (gp )
101
+ def conditional (self , name , n_points , Xnew , given = None ):
102
+ X , f , cov_total , mean_total = self ._get_cond_vals (given )
105
103
mu , cov = self ._build_conditional (Xnew , X , f , cov_total , mean_total )
106
104
chol = cholesky (stabilize (cov ))
107
105
return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
108
106
109
- def predict (self , Xnew , point = None , gp = None , diag = False ):
110
- X , f , cov_total , mean_total = self ._get_cond_vals (gp )
111
- mu , cov = self ._build_conditional (Xnew , X , f , cov_total , mean_total , diag )
112
- mu , cov = draw_values ([mu , cov ], point = point )
113
- return mu , cov
114
-
115
107
116
108
@conditioned_vars (["X" , "f" , "nu" ])
117
109
class TP (Latent ):
@@ -127,45 +119,42 @@ def __init__(self, mean_func=None, cov_func=None, nu=None):
127
119
def __add__ (self , other ):
128
120
raise ValueError ("Student T processes aren't additive" )
129
121
130
- def _build_prior (self , name , n_points , X , nu ):
122
+ def _build_prior (self , name , n_points , X , reparameterize = True ):
131
123
mu = self .mean_func (X )
132
124
chol = cholesky (stabilize (self .cov_func (X )))
133
-
134
- chi2 = pm .ChiSquared ("chi2_" , nu )
135
- v = pm .Normal (name + "_rotated_" , mu = 0.0 , sd = 1.0 , shape = n_points )
136
- f = pm .Deterministic (name , (tt .sqrt (nu ) / chi2 ) * (mu + tt .dot (chol , v )))
125
+ if reparameterize :
126
+ chi2 = pm .ChiSquared ("chi2_" , self .nu )
127
+ v = pm .Normal (name + "_rotated_" , mu = 0.0 , sd = 1.0 , shape = n_points )
128
+ f = pm .Deterministic (name , (tt .sqrt (self .nu ) / chi2 ) * (mu + tt .dot (chol , v )))
129
+ else :
130
+ f = pm .MvStudentT (name , nu = self .nu , mu = mu , chol = chol , shape = n_points )
137
131
return f
138
132
139
- def prior (self , name , n_points , X , nu ):
140
- f = self ._build_prior (name , n_points , X , nu )
133
+ def prior (self , name , n_points , X , reparameterize = True ):
134
+ f = self ._build_prior (name , n_points , X , reparameterize )
141
135
self .X = X
142
- self .nu = nu
143
136
self .f = f
144
137
return f
145
138
146
- def _build_conditional (self , Xnew , X , f , nu ):
147
- Kxx = self .cov_total (X )
139
+ def _build_conditional (self , Xnew , X , f ):
140
+ Kxx = self .cov_func (X )
148
141
Kxs = self .cov_func (X , Xnew )
149
142
Kss = self .cov_func (Xnew )
150
143
L = cholesky (stabilize (Kxx ))
151
144
A = solve_lower (L , Kxs )
152
145
cov = Kss - tt .dot (tt .transpose (A ), A )
153
-
154
- v = solve_lower (L , f - self .mean_total (X ))
146
+ v = solve_lower (L , f - self .mean_func (X ))
155
147
mu = self .mean_func (Xnew ) + tt .dot (tt .transpose (A ), v )
156
-
157
148
beta = tt .dot (v , v )
158
- nu2 = nu + X .shape [0 ]
159
-
160
- covT = (nu + beta - 2 )/ (nu2 - 2 ) * cov
149
+ nu2 = self .nu + X .shape [0 ]
150
+ covT = (self .nu + beta - 2 )/ (nu2 - 2 ) * cov
151
+ return nu2 , mu , covT
152
+
153
+ def conditional (self , name , n_points , Xnew ):
154
+ X = self .X
155
+ f = self .f
156
+ nu2 , mu , covT = self ._build_conditional (Xnew , X , f )
161
157
chol = cholesky (stabilize (covT ))
162
- return nu2 , mu , chol
163
-
164
- def conditional (self , name , n_points , Xnew , X = None , f = None , nu = None ):
165
- if X is None : X = self .X
166
- if f is None : f = self .f
167
- if nu is None : nu = self .nu
168
- nu2 , mu , chol = self ._build_conditional (Xnew , X , f , nu )
169
158
return pm .MvStudentT (name , nu = nu2 , mu = mu , chol = chol , shape = n_points )
170
159
171
160
@@ -226,16 +215,16 @@ def _get_cond_vals(self, other=None):
226
215
else :
227
216
return other .X , other .y , other .noise , other .cov_func , other .mean_func
228
217
229
- def conditional (self , name , n_points , Xnew , gp = None , pred_noise = False ):
218
+ def conditional (self , name , n_points , Xnew , given = None , pred_noise = False ):
230
219
# try to get n_points from X, (via cast to int?), error if cant and n_points is none
231
- X , y , noise , cov_total , mean_total = self ._get_cond_vals (gp )
220
+ X , y , noise , cov_total , mean_total = self ._get_cond_vals (given )
232
221
mu , cov = self ._build_conditional (Xnew , X , y , noise , cov_total , mean_total ,
233
222
pred_noise , diag = False )
234
223
chol = cholesky (cov )
235
224
return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
236
225
237
- def predict (self , Xnew , point = None , gp = None , pred_noise = False , diag = False ):
238
- X , y , noise , cov_total , mean_total = self ._get_cond_vals (gp )
226
+ def predict (self , Xnew , point = None , given = None , pred_noise = False , diag = False ):
227
+ X , y , noise , cov_total , mean_total = self ._get_cond_vals (given )
239
228
mu , cov = self ._build_conditional (Xnew , X , y , noise , cov_total , mean_total ,
240
229
pred_noise , diag )
241
230
mu , cov = draw_values ([mu , cov ], point = point )
@@ -291,7 +280,7 @@ def _build_marginal_likelihood_logp(self, X, Xu, y, sigma):
291
280
quadratic = 0.5 * (tt .dot (r , r_l ) - tt .dot (c , c ))
292
281
return - 1.0 * (constant + logdet + quadratic + trace )
293
282
294
- def marginal_likelihood (self , name , n_points , X , Xu , y , sigma , is_observed = True ):
283
+ def marginal_likelihood (self , name , X , Xu , y , sigma , n_points = None , is_observed = True ):
295
284
self .X = X
296
285
self .Xu = Xu
297
286
self .y = y
@@ -300,49 +289,66 @@ def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True)
300
289
if is_observed : # same thing ith n_points here?? check
301
290
return pm .DensityDist (name , logp , observed = y )
302
291
else :
303
- return pm .DensityDist (name , logp , size = n_points ) # need size? if not, dont need size arg
292
+ if n_points is None :
293
+ raise ValueError ("When `y` is not observed, `n_points` arg is required" )
294
+ return pm .DensityDist (name , logp , size = n_points ) # not, dont need size arg
304
295
305
- def _build_conditional (self , Xnew , Xu , X , y , sigma , pred_noise ):
296
+ def _build_conditional (self , Xnew , Xu , X , y , sigma , cov_total , mean_total ,
297
+ pred_noise , diag = False ):
306
298
sigma2 = tt .square (sigma )
307
- Kuu = self . cov_func (Xu )
308
- Kuf = self . cov_func (Xu , X )
299
+ Kuu = cov_total (Xu )
300
+ Kuf = cov_total (Xu , X )
309
301
Luu = cholesky (stabilize (Kuu ))
310
302
A = solve_lower (Luu , Kuf )
311
303
Qffd = tt .sum (A * A , 0 )
312
304
if self .approx not in self ._available_approx :
313
305
raise NotImplementedError (self .approx )
314
306
elif self .approx == "FITC" :
315
- Kffd = self . cov_func (X , diag = True )
307
+ Kffd = cov_total (X , diag = True )
316
308
Lamd = tt .clip (Kffd - Qffd , 0.0 , np .inf ) + sigma2
317
309
else : # VFE or DTC
318
310
Lamd = tt .ones_like (Qffd ) * sigma2
319
311
A_l = A / Lamd
320
312
L_B = cholesky (tt .eye (Xu .shape [0 ]) + tt .dot (A_l , tt .transpose (A )))
321
- r = y - self . mean_func (X )
313
+ r = y - mean_total (X )
322
314
r_l = r / Lamd
323
315
c = solve_lower (L_B , tt .dot (A , r_l ))
324
316
Kus = self .cov_func (Xu , Xnew )
325
317
As = solve_lower (Luu , Kus )
326
- mean = (self .mean_func (Xnew ) +
327
- tt .dot (tt .transpose (As ), solve_upper (tt .transpose (L_B ), c )))
318
+ mu = self .mean_func (Xnew ) + tt .dot (tt .transpose (As ), solve_upper (tt .transpose (L_B ), c ))
328
319
C = solve_lower (L_B , As )
329
- if pred_noise :
330
- cov = (self .cov_func (Xnew ) - tt .dot (tt .transpose (As ), As ) +
331
- tt .dot (tt .transpose (C ), C ) + sigma2 * tt .eye (Xnew .shape [0 ]))
320
+ if diag :
321
+ Kss = self .cov_func (Xnew , diag = True )
322
+ var = Kss - tt .sum (tt .sqaure (As ), 0 ) + tt .sum (tt .square (C ), 0 )
323
+ if pred_noise :
324
+ var += sigma2
325
+ return mu , var
332
326
else :
333
327
cov = (self .cov_func (Xnew ) - tt .dot (tt .transpose (As ), As ) +
334
328
tt .dot (tt .transpose (C ), C ))
335
- return mean , stabilize (cov )
336
-
337
- def conditional (self , name , n_points , Xnew , Xu = None , X = None , y = None ,
338
- sigma = None , pred_noise = False ):
339
- if Xu is None : Xu = self .Xu
340
- if X is None : X = self .X
341
- if y is None : y = self .y
342
- if sigma is None : sigma = self .sigma
343
- mu , chol = self ._build_conditional (Xnew , Xu , X , y , sigma , pred_noise )
344
- return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
329
+ if pred_noise :
330
+ cov += sigma2 * tt .identity_like (cov )
331
+ return mu , stabilize (cov )
332
+
333
+ def _get_cond_vals (self , other = None ):
334
+ if other is None :
335
+ return self .X , self .Xu , self .y , self .sigma , self .cov_func , self .mean_func ,
336
+ else :
337
+ return other .X , self .Xu , other .y , other .sigma , other .cov_func , other .mean_func
345
338
339
+ def conditional (self , name , n_points , Xnew , given = None , pred_noise = False ):
340
+ # try to get n_points from X, (via cast to int?), error if cant and n_points is none
341
+ X , Xu , y , sigma , cov_total , mean_total = self ._get_cond_vals (given )
342
+ mu , cov = self ._build_conditional (Xnew , Xu , X , y , sigma , cov_total , mean_total ,
343
+ pred_noise , diag = False )
344
+ chol = cholesky (cov )
345
+ return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
346
346
347
+ def predict (self , Xnew , point = None , given = None , pred_noise = False , diag = False ):
348
+ X , Xu , y , sigma , cov_total , mean_total = self ._get_cond_vals (given )
349
+ mu , cov = self ._build_conditional (Xnew , Xu , X , y , sigma , cov_total , mean_total ,
350
+ pred_noise , diag )
351
+ mu , cov = draw_values ([mu , cov ], point = point )
352
+ return mu , cov
347
353
348
354
0 commit comments