@@ -195,22 +195,30 @@ class Stationary(Covariance):
195
195
196
196
Parameters
197
197
----------
198
- lengthscales : If input_dim > 1, a list or array of scalars or PyMC3 random
198
+ ls : If input_dim > 1, a list or array of scalars or PyMC3 random
199
199
variables. If input_dim == 1, a scalar or PyMC3 random variable.
200
+ ls_inv : 1 / ls. One of ls or ls_inv must be provided.
200
201
"""
201
202
202
- def __init__ (self , input_dim , lengthscales , active_dims = None ):
203
+ def __init__ (self , input_dim , ls = None , ls_inv = None , active_dims = None ):
203
204
super (Stationary , self ).__init__ (input_dim , active_dims )
204
- self .lengthscales = tt .as_tensor_variable (lengthscales )
205
+ if (ls is None and ls_inv is None ) or (ls is not None and ls_inv is not None ):
206
+ raise ValueError ("Only one of 'ls' or 'ls_inv' must be provided" )
207
+ elif ls_inv is not None :
208
+ if isinstance (ls_inv , (np .ndarray , list , tuple )):
209
+ ls = 1.0 / np .asarray (ls_inv )
210
+ else :
211
+ ls = 1.0 / ls_inv
212
+ self .ls = tt .as_tensor_variable (ls )
205
213
206
214
def square_dist (self , X , Xs ):
207
- X = tt .mul (X , 1.0 / self .lengthscales )
215
+ X = tt .mul (X , 1.0 / self .ls )
208
216
X2 = tt .sum (tt .square (X ), 1 )
209
217
if Xs is None :
210
218
sqd = (- 2.0 * tt .dot (X , tt .transpose (X ))
211
219
+ (tt .reshape (X2 , (- 1 , 1 )) + tt .reshape (X2 , (1 , - 1 ))))
212
220
else :
213
- Xs = tt .mul (Xs , 1.0 / self .lengthscales )
221
+ Xs = tt .mul (Xs , 1.0 / self .ls )
214
222
Xs2 = tt .sum (tt .square (Xs ), 1 )
215
223
sqd = (- 2.0 * tt .dot (X , tt .transpose (Xs ))
216
224
+ (tt .reshape (X2 , (- 1 , 1 )) + tt .reshape (Xs2 , (1 , - 1 ))))
@@ -228,8 +236,8 @@ def full(self, X, Xs=None):
228
236
229
237
230
238
class Periodic (Stationary ):
231
- def __init__ (self , input_dim , lengthscales , period , active_dims = None ):
232
- super (Periodic , self ).__init__ (input_dim , lengthscales , active_dims )
239
+ def __init__ (self , input_dim , period , ls = None , ls_inv = None , active_dims = None ):
240
+ super (Periodic , self ).__init__ (input_dim , ls , ls_inv , active_dims )
233
241
self .period = period
234
242
def full (self , X , Xs = None ):
235
243
X , Xs = self ._slice (X , Xs )
@@ -238,7 +246,7 @@ def full(self, X, Xs=None):
238
246
f1 = X .dimshuffle (0 , 'x' , 1 )
239
247
f2 = Xs .dimshuffle ('x' , 0 , 1 )
240
248
r = np .pi * (f1 - f2 ) / self .period
241
- r = tt .sum (tt .square (tt .sin (r ) / self .lengthscales ), 2 )
249
+ r = tt .sum (tt .square (tt .sin (r ) / self .ls ), 2 )
242
250
return tt .exp (- 0.5 * r )
243
251
244
252
@@ -266,8 +274,8 @@ class RatQuad(Stationary):
266
274
k(x, x') = \left(1 + \frac{(x - x')^2}{2\alpha\ell^2} \right)^{-\alpha}
267
275
"""
268
276
269
- def __init__ (self , input_dim , lengthscales , alpha , active_dims = None ):
270
- super (RatQuad , self ).__init__ (input_dim , lengthscales , active_dims )
277
+ def __init__ (self , input_dim , alpha , ls , ls_inv , active_dims = None ):
278
+ super (RatQuad , self ).__init__ (input_dim , ls , ls_inv , active_dims )
271
279
self .alpha = alpha
272
280
273
281
def full (self , X , Xs = None ):
0 commit comments