import pymc3 as pm
from pymc3.gp.cov import Covariance
+ from pymc3.gp.mean import Constant
from pymc3.gp.util import conditioned_vars
+ from pymc3.distributions import draw_values


__all__ = ['Latent', 'Marginal', 'TP', 'MarginalSparse']
@@ -20,7 +22,7 @@ def stabilize(K):
    return K + 1e-6 * tt.identity_like(K)


- class GPBase(object):
+ class Base(object):
    """
    Base class
    """
@@ -36,44 +38,13 @@ def __init__(self, mean_func=None, cov_func=None):
        self.mean_func = mean_func
        self.cov_func = cov_func

-     @property
-     def cov_total(self):
-         total = getattr(self, "_cov_total", None)
-         if total is None:
-             return self.cov_func
-         else:
-             return total
-
-     @cov_total.setter
-     def cov_total(self, new_cov_total):
-         self._cov_total = new_cov_total
-
-     @property
-     def mean_total(self):
-         total = getattr(self, "_mean_total", None)
-         if total is None:
-             return self.mean_func
-         else:
-             return total
-
-     @mean_total.setter
-     def mean_total(self, new_mean_total):
-         self._mean_total = new_mean_total
-
    def __add__(self, other):
        same_attrs = set(self.__dict__.keys()) == set(other.__dict__.keys())
        if not isinstance(self, type(other)) and not same_attrs:
            raise ValueError("can't add different GP types")
-
-         # set cov_func and mean_func of new GP
-         cov_total = self.cov_func + other.cov_func
        mean_total = self.mean_func + other.mean_func
-
-         # update self and other mean and cov totals
-         self.cov_total, self.mean_total = (cov_total, mean_total)
-         other.cov_total, other.mean_total = (cov_total, mean_total)
-         new_gp = self.__class__(mean_total, cov_total)
-         return new_gp
+         cov_total = self.cov_func + other.cov_func
+         return self.__class__(mean_total, cov_total)

    def prior(self, name, X, *args, **kwargs):
        raise NotImplementedError
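
For orientation, a minimal sketch of what the reworked `__add__` now does: it sums the two mean functions and the two covariance functions and returns a fresh GP of the same class, instead of mutating `cov_total`/`mean_total` on both operands. The covariance objects and module-level names below (`pm.gp.Latent`, `pm.gp.cov.ExpQuad`, `pm.gp.cov.Matern52`, `pm.gp.mean.Zero`) are assumptions about how these classes are exposed, not something shown in this diff.

import pymc3 as pm

# Assumed covariance and mean functions, for illustration only.
cov1 = pm.gp.cov.ExpQuad(1, 0.1)
cov2 = pm.gp.cov.Matern52(1, 0.5)
gp1 = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov1)
gp2 = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov2)

# Returns a new Latent whose cov_func is cov1 + cov2 and whose mean_func is
# Zero() + Zero(); gp1 and gp2 are left untouched.
gp_sum = gp1 + gp2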
@@ -84,9 +55,11 @@ def marginal_likelihood(self, name, X, *args, **kwargs):
    def conditional(self, name, n_points, Xnew, *args, **kwargs):
        raise NotImplementedError

+     def predict(self, Xnew, point=None, given=None, diag=False):
+         raise NotImplementedError

@conditioned_vars(["X", "f"])
- class Latent(GPBase):
+ class Latent(Base):
    """ Where the GP f isn't integrated out, and is sampled explicitly
    """

    def __init__(self, mean_func=None, cov_func=None):
@@ -100,30 +73,35 @@ def _build_prior(self, name, n_points, X, reparameterize=True):
            f = pm.Deterministic(name, mu + tt.dot(chol, v))
        else:
            f = pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
-         self.X = X
-         self.f = f
        return f

    def prior(self, name, n_points, X, reparameterize=True):
        f = self._build_prior(name, n_points, X, reparameterize)
+         self.X = X
+         self.f = f
        return f

-     def _build_conditional(self, Xnew, X, f):
-         Kxx = self.cov_total(X)
+     def _get_cond_vals(self, other=None):
+         if other is None:
+             return self.X, self.f, self.cov_func, self.mean_func,
+         else:
+             return other.X, other.f, other.cov_func, other.mean_func
+
+     def _build_conditional(self, Xnew, X, f, cov_total, mean_total):
+         Kxx = cov_total(X)
        Kxs = self.cov_func(X, Xnew)
-         Kss = self.cov_func(Xnew)
        L = cholesky(stabilize(Kxx))
        A = solve_lower(L, Kxs)
-         cov = Kss - tt.dot(tt.transpose(A), A)
-         chol = cholesky(stabilize(cov))
-         v = solve_lower(L, f - self.mean_total(X))
+         v = solve_lower(L, f - mean_total(X))
        mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
-         return mu, chol
+         Kss = self.cov_func(Xnew)
+         cov = Kss - tt.dot(tt.transpose(A), A)
+         return mu, cov

-     def conditional(self, name, n_points, Xnew, X=None, f=None):
-         if X is None: X = self.X
-         if f is None: f = self.f
-         mu, chol = self._build_conditional(Xnew, X, f)
+     def conditional(self, name, n_points, Xnew, given=None):
+         X, f, cov_total, mean_total = self._get_cond_vals(given)
+         mu, cov = self._build_conditional(Xnew, X, f, cov_total, mean_total)
+         chol = cholesky(stabilize(cov))
        return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
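
As an illustration of the new `given`-based conditioning in `Latent`, a minimal sketch; the data, names, and lengthscale below are made up, and the calls follow the signatures in this diff (so `n_points` is still passed explicitly).

import numpy as np
import pymc3 as pm

X = np.linspace(0, 1, 20)[:, None]
Xnew = np.linspace(0, 2, 50)[:, None]

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 0.1)
    gp = pm.gp.Latent(mean_func=pm.gp.mean.Zero(), cov_func=cov)
    f = gp.prior("f", n_points=20, X=X)    # prior() now stores X and f on the gp
    # conditional() pulls X, f, cov_func, mean_func from this gp unless `given`
    # supplies another (e.g. additive) GP to condition on.
    fnew = gp.conditional("fnew", 50, Xnew)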
@@ -141,58 +119,54 @@ def __init__(self, mean_func=None, cov_func=None, nu=None):
    def __add__(self, other):
        raise ValueError("Student T processes aren't additive")

-     def _build_prior(self, name, n_points, X, nu):
+     def _build_prior(self, name, n_points, X, reparameterize=True):
        mu = self.mean_func(X)
        chol = cholesky(stabilize(self.cov_func(X)))
+         if reparameterize:
+             chi2 = pm.ChiSquared("chi2_", self.nu)
+             v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
+             f = pm.Deterministic(name, (tt.sqrt(self.nu) / chi2) * (mu + tt.dot(chol, v)))
+         else:
+             f = pm.MvStudentT(name, nu=self.nu, mu=mu, chol=chol, shape=n_points)
+         return f

-         chi2 = pm.ChiSquared("chi2_", nu)
-         v = pm.Normal(name + "_rotated_", mu=0.0, sd=1.0, shape=n_points)
-         f = pm.Deterministic(name, (tt.sqrt(nu) / chi2) * (mu + tt.dot(chol, v)))
-
+     def prior(self, name, n_points, X, reparameterize=True):
+         f = self._build_prior(name, n_points, X, reparameterize)
        self.X = X
        self.f = f
-         self.nu = nu
        return f

-     def prior(self, name, n_points, X, nu):
-         f = self._build_prior(name, n_points, X, nu)
-         return f
-
-     def _build_conditional(self, Xnew, X, f, nu):
-         Kxx = self.cov_total(X)
+     def _build_conditional(self, Xnew, X, f):
+         Kxx = self.cov_func(X)
        Kxs = self.cov_func(X, Xnew)
        Kss = self.cov_func(Xnew)
        L = cholesky(stabilize(Kxx))
        A = solve_lower(L, Kxs)
        cov = Kss - tt.dot(tt.transpose(A), A)
-
-         v = solve_lower(L, f - self.mean_total(X))
+         v = solve_lower(L, f - self.mean_func(X))
        mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
-
        beta = tt.dot(v, v)
-         nu2 = nu + X.shape[0]
-
-         covT = (nu + beta - 2) / (nu2 - 2) * cov
+         nu2 = self.nu + X.shape[0]
+         covT = (self.nu + beta - 2) / (nu2 - 2) * cov
+         return nu2, mu, covT
+
+     def conditional(self, name, n_points, Xnew):
+         X = self.X
+         f = self.f
+         nu2, mu, covT = self._build_conditional(Xnew, X, f)
        chol = cholesky(stabilize(covT))
-         return nu2, mu, chol
-
-     def conditional(self, name, n_points, Xnew, X=None, f=None, nu=None):
-         if X is None: X = self.X
-         if f is None: f = self.f
-         if nu is None: nu = self.nu
-         nu2, mu, chol = self._build_conditional(Xnew, X, f, nu)
        return pm.MvStudentT(name, nu=nu2, mu=mu, chol=chol, shape=n_points)
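
For reference, the Student-t predictive above rescales the usual GP predictive covariance by (nu + beta - 2) / (nu2 - 2) with nu2 = nu + n, where beta = v'v and n is the number of conditioning points. A small arithmetic sketch with made-up numbers:

# Made-up values, only to show the scaling used in _build_conditional above.
nu, n = 3.0, 20                        # degrees of freedom, conditioning points
beta = 12.5                            # stands in for tt.dot(v, v)
nu2 = nu + n                           # updated degrees of freedom
scale = (nu + beta - 2) / (nu2 - 2)    # multiplies the GP predictive covariance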


@conditioned_vars(["X", "y", "noise"])
- class Marginal(GPBase):
+ class Marginal(Base):

    def __init__(self, mean_func=None, cov_func=None):
        super(Marginal, self).__init__(mean_func, cov_func)

    def _build_marginal_likelihood(self, X, noise):
        mu = self.mean_func(X)
-         Kxx = self.cov_total(X)
+         Kxx = self.cov_func(X)
        Knx = noise(X)
        cov = Kxx + Knx
        chol = cholesky(stabilize(cov))
@@ -212,38 +186,53 @@ def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True
                raise ValueError("When `y` is not observed, `n_points` arg is required")
            return pm.MvNormal(name, mu=mu, chol=chol, size=n_points)

-     def _build_conditional(self, Xnew, X, y, noise, pred_noise):
-         Kxx = self.cov_total(X)
+     def _build_conditional(self, Xnew, X, y, noise, cov_total, mean_total,
+                            pred_noise, diag=False):
+         Kxx = cov_total(X)
        Kxs = self.cov_func(X, Xnew)
-         Kss = self.cov_func(Xnew)
        Knx = noise(X)
-         rxx = y - self.mean_total(X)
+         rxx = y - mean_total(X)
        L = cholesky(stabilize(Kxx) + Knx)
        A = solve_lower(L, Kxs)
        v = solve_lower(L, rxx)
        mu = self.mean_func(Xnew) + tt.dot(tt.transpose(A), v)
-         if pred_noise:
-             cov = noise(Xnew) + Kss - tt.dot(tt.transpose(A), A)
+         if diag:
+             Kss = self.cov_func(Xnew, diag=True)
+             var = Kss - tt.sum(tt.square(A), 0)
+             if pred_noise:
+                 var += noise(Xnew, diag=True)
+             return mu, var
        else:
-             cov = stabilize(Kss) - tt.dot(tt.transpose(A), A)
-         chol = cholesky(cov)
-         return mu, chol
-
-     def conditional(self, name, n_points, Xnew, X=None, y=None,
-                     noise=None, pred_noise=False):
-         if X is None: X = self.X
-         if y is None: y = self.y
-         if noise is None:
-             noise = self.noise
+             Kss = self.cov_func(Xnew)
+             cov = Kss - tt.dot(tt.transpose(A), A)
+             if pred_noise:
+                 cov += noise(Xnew)
+             return mu, stabilize(cov)
+
+     def _get_cond_vals(self, other=None):
+         if other is None:
+             return self.X, self.y, self.noise, self.cov_func, self.mean_func,
        else:
-             if not isinstance(noise, Covariance):
-                 noise = pm.gp.cov.WhiteNoise(noise)
-         mu, chol = self._build_conditional(Xnew, X, y, noise, pred_noise)
+             return other.X, other.y, other.noise, other.cov_func, other.mean_func
+
+     def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
+         # try to get n_points from X (via cast to int?); error if can't and n_points is None
+         X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
+         mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
+                                           pred_noise, diag=False)
+         chol = cholesky(cov)
        return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)

+     def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
+         X, y, noise, cov_total, mean_total = self._get_cond_vals(given)
+         mu, cov = self._build_conditional(Xnew, X, y, noise, cov_total, mean_total,
+                                           pred_noise, diag)
+         mu, cov = draw_values([mu, cov], point=point)
+         return mu, cov
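
A minimal usage sketch of `Marginal` with the new `predict` method; the data, noise level, and lengthscale are made up, and `pm.gp.Marginal`, `pm.gp.cov.ExpQuad`, `pm.gp.cov.WhiteNoise`, and `pm.gp.mean.Zero` are assumed to be exposed as in the released API.

import numpy as np
import pymc3 as pm

X = np.random.rand(30)[:, None]
y = np.random.randn(30)
Xnew = np.linspace(0, 1, 100)[:, None]

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 0.2)
    gp = pm.gp.Marginal(mean_func=pm.gp.mean.Zero(), cov_func=cov)
    y_ = gp.marginal_likelihood("y", X=X, y=y, noise=pm.gp.cov.WhiteNoise(0.1))
    mp = pm.find_MAP()
    # predict evaluates the symbolic mu, cov at `point` via draw_values.
    mu, cov = gp.predict(Xnew, point=mp)
    mu, var = gp.predict(Xnew, point=mp, diag=True)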
+

@conditioned_vars(["X", "Xu", "y", "sigma"])
- class MarginalSparse(GPBase):
+ class MarginalSparse(Base):
    _available_approx = ["FITC", "VFE", "DTC"]
    """ FITC and VFE sparse approximations
    """
@@ -291,7 +280,7 @@ def _build_marginal_likelihood_logp(self, X, Xu, y, sigma):
        quadratic = 0.5 * (tt.dot(r, r_l) - tt.dot(c, c))
        return -1.0 * (constant + logdet + quadratic + trace)

-     def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True):
+     def marginal_likelihood(self, name, X, Xu, y, sigma, n_points=None, is_observed=True):
        self.X = X
        self.Xu = Xu
        self.y = y
@@ -300,49 +289,66 @@ def marginal_likelihood(self, name, n_points, X, Xu, y, sigma, is_observed=True)
        if is_observed:  # same thing with n_points here?? check
            return pm.DensityDist(name, logp, observed=y)
        else:
-             return pm.DensityDist(name, logp, size=n_points)  # need size? if not, don't need size arg
+             if n_points is None:
+                 raise ValueError("When `y` is not observed, `n_points` arg is required")
+             return pm.DensityDist(name, logp, size=n_points)  # need size? if not, don't need size arg

-     def _build_conditional(self, Xnew, Xu, X, y, sigma, pred_noise):
+     def _build_conditional(self, Xnew, Xu, X, y, sigma, cov_total, mean_total,
+                            pred_noise, diag=False):
        sigma2 = tt.square(sigma)
-         Kuu = self.cov_func(Xu)
-         Kuf = self.cov_func(Xu, X)
+         Kuu = cov_total(Xu)
+         Kuf = cov_total(Xu, X)
        Luu = cholesky(stabilize(Kuu))
        A = solve_lower(Luu, Kuf)
        Qffd = tt.sum(A * A, 0)
        if self.approx not in self._available_approx:
            raise NotImplementedError(self.approx)
        elif self.approx == "FITC":
-             Kffd = self.cov_func(X, diag=True)
+             Kffd = cov_total(X, diag=True)
            Lamd = tt.clip(Kffd - Qffd, 0.0, np.inf) + sigma2
        else:  # VFE or DTC
            Lamd = tt.ones_like(Qffd) * sigma2
        A_l = A / Lamd
        L_B = cholesky(tt.eye(Xu.shape[0]) + tt.dot(A_l, tt.transpose(A)))
-         r = y - self.mean_func(X)
+         r = y - mean_total(X)
        r_l = r / Lamd
        c = solve_lower(L_B, tt.dot(A, r_l))
        Kus = self.cov_func(Xu, Xnew)
        As = solve_lower(Luu, Kus)
-         mean = (self.mean_func(Xnew) +
-                 tt.dot(tt.transpose(As), solve_upper(tt.transpose(L_B), c)))
+         mu = self.mean_func(Xnew) + tt.dot(tt.transpose(As), solve_upper(tt.transpose(L_B), c))
        C = solve_lower(L_B, As)
-         if pred_noise:
-             cov = (self.cov_func(Xnew) - tt.dot(tt.transpose(As), As) +
-                    tt.dot(tt.transpose(C), C) + sigma2 * tt.eye(Xnew.shape[0]))
+         if diag:
+             Kss = self.cov_func(Xnew, diag=True)
+             var = Kss - tt.sum(tt.square(As), 0) + tt.sum(tt.square(C), 0)
+             if pred_noise:
+                 var += sigma2
+             return mu, var
        else:
            cov = (self.cov_func(Xnew) - tt.dot(tt.transpose(As), As) +
                   tt.dot(tt.transpose(C), C))
-         return mean, stabilize(cov)
-
-     def conditional(self, name, n_points, Xnew, Xu=None, X=None, y=None,
-                     sigma=None, pred_noise=False):
-         if Xu is None: Xu = self.Xu
-         if X is None: X = self.X
-         if y is None: y = self.y
-         if sigma is None: sigma = self.sigma
-         mu, chol = self._build_conditional(Xnew, Xu, X, y, sigma, pred_noise)
-         return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
+             if pred_noise:
+                 cov += sigma2 * tt.identity_like(cov)
+             return mu, stabilize(cov)

+     def _get_cond_vals(self, other=None):
+         if other is None:
+             return self.X, self.Xu, self.y, self.sigma, self.cov_func, self.mean_func,
+         else:
+             return other.X, self.Xu, other.y, other.sigma, other.cov_func, other.mean_func
+
+     def conditional(self, name, n_points, Xnew, given=None, pred_noise=False):
+         # try to get n_points from X (via cast to int?); error if can't and n_points is None
+         X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
+         mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
+                                           pred_noise, diag=False)
+         chol = cholesky(cov)
+         return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)

+     def predict(self, Xnew, point=None, given=None, pred_noise=False, diag=False):
+         X, Xu, y, sigma, cov_total, mean_total = self._get_cond_vals(given)
+         mu, cov = self._build_conditional(Xnew, Xu, X, y, sigma, cov_total, mean_total,
+                                           pred_noise, diag)
+         mu, cov = draw_values([mu, cov], point=point)
+         return mu, cov
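
And a corresponding sketch for `MarginalSparse`, again with made-up data; the `approx` keyword on the constructor is an assumption, since `__init__` is not shown in this diff.

import numpy as np
import pymc3 as pm

X = np.random.rand(200)[:, None]
y = np.random.randn(200)
Xu = np.linspace(0, 1, 20)[:, None]     # inducing point locations (assumed)
Xnew = np.linspace(0, 1, 100)[:, None]

with pm.Model() as model:
    cov = pm.gp.cov.ExpQuad(1, 0.2)
    gp = pm.gp.MarginalSparse(mean_func=pm.gp.mean.Zero(), cov_func=cov,
                              approx="FITC")
    y_ = gp.marginal_likelihood("y", X=X, Xu=Xu, y=y, sigma=0.1)
    mp = pm.find_MAP()
    mu, var = gp.predict(Xnew, point=mp, diag=True)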