6
6
7
7
import pymc3 as pm
8
8
from pymc3 .gp .cov import Covariance
9
+ from pymc3 .gp .mean import Constant
9
10
from pymc3 .gp .util import conditioned_vars
10
11
from pymc3 .distributions import draw_values
11
12
@@ -21,7 +22,7 @@ def stabilize(K):
21
22
return K + 1e-6 * tt .identity_like (K )
22
23
23
24
24
- class GPBase (object ):
25
+ class Base (object ):
25
26
"""
26
27
Base class
27
28
"""
@@ -37,28 +38,6 @@ def __init__(self, mean_func=None, cov_func=None):
37
38
self .mean_func = mean_func
38
39
self .cov_func = cov_func
39
40
40
- #@property
41
- #def cov_total(self):
42
- # total = getattr(self, "_cov_total", None)
43
- # if total is None:
44
- # return self.cov_func
45
- # else:
46
- # return total
47
- #@cov_total.setter
48
- #def cov_total(self, new_cov_total):
49
- # self._cov_total = new_cov_total
50
-
51
- #@property
52
- #def mean_total(self):
53
- # total = getattr(self, "_mean_total", None)
54
- # if total is None:
55
- # return self.mean_func
56
- # else:
57
- # return total
58
- #@mean_total.setter
59
- #def mean_total(self, new_mean_total):
60
- # self._mean_total = new_mean_total
61
-
62
41
def __add__ (self , other ):
63
42
same_attrs = set (self .__dict__ .keys ()) == set (other .__dict__ .keys ())
64
43
if not isinstance (self , type (other )) and not same_attrs :
@@ -78,7 +57,7 @@ def conditional(self, name, n_points, Xnew, *args, **kwargs):
78
57
79
58
80
59
@conditioned_vars (["X" , "f" ])
81
- class Latent (GPBase ):
60
+ class Latent (Base ):
82
61
""" Where the GP f isnt integrated out, and is sampled explicitly
83
62
"""
84
63
def __init__ (self , mean_func = None , cov_func = None ):
@@ -92,40 +71,46 @@ def _build_prior(self, name, n_points, X, reparameterize=True):
92
71
f = pm .Deterministic (name , mu + tt .dot (chol , v ))
93
72
else :
94
73
f = pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
95
- self .X = X
96
- self .f = f
97
74
return f
98
75
99
76
def prior (self , name , n_points , X , reparameterize = True ):
100
77
f = self ._build_prior (name , n_points , X , reparameterize )
78
+ self .X = X
79
+ self .f = f
101
80
return f
102
81
103
- def _build_conditional (self , Xnew , X , f ):
104
- Kxx = self .cov_total (X )
82
+ def _get_cond_vals (self , other = None ):
83
+ if other is None :
84
+ return self .X , self .f , self .cov_func , self .mean_func ,
85
+ else :
86
+ return other .X , other .f , other .cov_func , other .mean_func
87
+
88
+ def _build_conditional (self , Xnew , X , f , cov_total , mean_total , diag = False ):
89
+ Kxx = cov_total (X )
105
90
Kxs = self .cov_func (X , Xnew )
106
- Kss = self .cov_func (Xnew )
107
91
L = cholesky (stabilize (Kxx ))
108
92
A = solve_lower (L , Kxs )
109
- cov = Kss - tt .dot (tt .transpose (A ), A )
110
- chol = cholesky (stabilize (cov ))
111
- v = solve_lower (L , f - self .mean_total (X ))
93
+ v = solve_lower (L , f - mean_total (X ))
112
94
mu = self .mean_func (Xnew ) + tt .dot (tt .transpose (A ), v )
113
- return mu , chol
95
+ if diag :
96
+ Kss = self .cov_func (Xnew , diag = True )
97
+ cov = Kss - tt .sum (tt .square (A ), 0 )
98
+ else :
99
+ Kss = self .cov_func (Xnew )
100
+ cov = Kss - tt .dot (tt .transpose (A ), A )
101
+ return mu , cov
114
102
115
- def conditional (self , name , n_points , Xnew , X = None , f = None ):
116
- if X is None : X = self .X
117
- if f is None : f = self .f
118
- mu , chol = self . _build_conditional ( Xnew , X , f )
103
+ def conditional (self , name , n_points , Xnew , gp = None ):
104
+ X , f , cov_total , mean_total = self ._get_cond_vals ( gp )
105
+ mu , cov = self ._build_conditional ( Xnew , X , f , cov_total , mean_total )
106
+ chol = cholesky ( stabilize ( cov ) )
119
107
return pm .MvNormal (name , mu = mu , chol = chol , shape = n_points )
120
108
121
- def conditional2 (self , name , n_points , X_new , gp_fitted = None ):
122
- # cant condition on a gp that hasnt been fit, so gp should have gp.f
123
- # the changing X thing is dumb, get rid of it?
124
-
125
- if gp is not None :
126
- X = gp .X
127
- f = gp .f
128
-
109
+ def predict (self , Xnew , point = None , gp = None , diag = False ):
110
+ X , f , cov_total , mean_total = self ._get_cond_vals (gp )
111
+ mu , cov = self ._build_conditional (Xnew , X , f , cov_total , mean_total , diag )
112
+ mu , cov = draw_values ([mu , cov ], point = point )
113
+ return mu , cov
129
114
130
115
131
116
@conditioned_vars (["X" , "f" , "nu" ])
@@ -149,14 +134,13 @@ def _build_prior(self, name, n_points, X, nu):
149
134
chi2 = pm .ChiSquared ("chi2_" , nu )
150
135
v = pm .Normal (name + "_rotated_" , mu = 0.0 , sd = 1.0 , shape = n_points )
151
136
f = pm .Deterministic (name , (tt .sqrt (nu ) / chi2 ) * (mu + tt .dot (chol , v )))
152
-
153
- self .X = X
154
- self .f = f
155
- self .nu = nu
156
137
return f
157
138
158
139
def prior (self , name , n_points , X , nu ):
159
140
f = self ._build_prior (name , n_points , X , nu )
141
+ self .X = X
142
+ self .nu = nu
143
+ self .f = f
160
144
return f
161
145
162
146
def _build_conditional (self , Xnew , X , f , nu ):
@@ -186,7 +170,7 @@ def conditional(self, name, n_points, Xnew, X=None, f=None, nu=None):
186
170
187
171
188
172
@conditioned_vars (["X" , "y" , "noise" ])
189
- class Marginal (GPBase ):
173
+ class Marginal (Base ):
190
174
191
175
def __init__ (self , mean_func = None , cov_func = None ):
192
176
super (Marginal , self ).__init__ (mean_func , cov_func )
@@ -215,7 +199,6 @@ def marginal_likelihood(self, name, X, y, noise, n_points=None, is_observed=True
215
199
216
200
def _build_conditional (self , Xnew , X , y , noise , cov_total , mean_total ,
217
201
pred_noise , diag = False ):
218
- # when not conditioning on another gp, cov_total = self.cov_func
219
202
Kxx = cov_total (X )
220
203
Kxs = self .cov_func (X , Xnew )
221
204
Knx = noise (X )
@@ -258,18 +241,9 @@ def predict(self, Xnew, point=None, gp=None, pred_noise=False, diag=False):
258
241
mu , cov = draw_values ([mu , cov ], point = point )
259
242
return mu , cov
260
243
261
- #def conditional(self, name, n_points, Xnew, X=None, y=None,
262
- # noise=None, pred_noise=False):
263
- # X, y, noise = self._get_cond_vals(X, y, noise)
264
- # mu, cov = self._build_conditional(Xnew, X, y, noise, pred_noise, diag=False)
265
- # chol = cholesky(cov)
266
- # return pm.MvNormal(name, mu=mu, chol=chol, shape=n_points)
267
-
268
-
269
-
270
244
271
245
@conditioned_vars (["X" , "Xu" , "y" , "sigma" ])
272
- class MarginalSparse (GPBase ):
246
+ class MarginalSparse (Base ):
273
247
_available_approx = ["FITC" , "VFE" , "DTC" ]
274
248
""" FITC and VFE sparse approximations
275
249
"""
0 commit comments