Skip to content

Commit 32006a4

Browse files
committed
allow inverse lengthscale, shorten lengthscale varname
1 parent 777ff26 commit 32006a4

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

pymc3/gp/cov.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -195,22 +195,30 @@ class Stationary(Covariance):
195195
196196
Parameters
197197
----------
198-
lengthscales: If input_dim > 1, a list or array of scalars or PyMC3 random
198+
ls : If input_dim > 1, a list or array of scalars or PyMC3 random
199199
variables. If input_dim == 1, a scalar or PyMC3 random variable.
200+
ls_inv : 1 / ls. One of ls or ls_inv must be provided.
200201
"""
201202

202-
def __init__(self, input_dim, lengthscales, active_dims=None):
203+
def __init__(self, input_dim, ls=None, ls_inv=None, active_dims=None):
203204
super(Stationary, self).__init__(input_dim, active_dims)
204-
self.lengthscales = tt.as_tensor_variable(lengthscales)
205+
if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None):
206+
raise ValueError("Only one of 'ls' or 'ls_inv' must be provided")
207+
elif ls_inv is not None:
208+
if isinstance(ls_inv, (np.ndarray, list, tuple)):
209+
ls = 1.0 / np.asarray(ls_inv)
210+
else:
211+
ls = 1.0 / ls_inv
212+
self.ls = tt.as_tensor_variable(ls)
205213

206214
def square_dist(self, X, Xs):
207-
X = tt.mul(X, 1.0 / self.lengthscales)
215+
X = tt.mul(X, 1.0 / self.ls)
208216
X2 = tt.sum(tt.square(X), 1)
209217
if Xs is None:
210218
sqd = (-2.0 * tt.dot(X, tt.transpose(X))
211219
+ (tt.reshape(X2, (-1, 1)) + tt.reshape(X2, (1, -1))))
212220
else:
213-
Xs = tt.mul(Xs, 1.0 / self.lengthscales)
221+
Xs = tt.mul(Xs, 1.0 / self.ls)
214222
Xs2 = tt.sum(tt.square(Xs), 1)
215223
sqd = (-2.0 * tt.dot(X, tt.transpose(Xs))
216224
+ (tt.reshape(X2, (-1, 1)) + tt.reshape(Xs2, (1, -1))))
@@ -228,8 +236,8 @@ def full(self, X, Xs=None):
228236

229237

230238
class Periodic(Stationary):
231-
def __init__(self, input_dim, lengthscales, period, active_dims=None):
232-
super(Periodic, self).__init__(input_dim, lengthscales, active_dims)
239+
def __init__(self, input_dim, period, ls=None, ls_inv=None, active_dims=None):
240+
super(Periodic, self).__init__(input_dim, ls, ls_inv, active_dims)
233241
self.period = period
234242
def full(self, X, Xs=None):
235243
X, Xs = self._slice(X, Xs)
@@ -238,7 +246,7 @@ def full(self, X, Xs=None):
238246
f1 = X.dimshuffle(0, 'x', 1)
239247
f2 = Xs.dimshuffle('x', 0, 1)
240248
r = np.pi * (f1 - f2) / self.period
241-
r = tt.sum(tt.square(tt.sin(r) / self.lengthscales), 2)
249+
r = tt.sum(tt.square(tt.sin(r) / self.ls), 2)
242250
return tt.exp(-0.5 * r)
243251

244252

@@ -266,8 +274,8 @@ class RatQuad(Stationary):
266274
k(x, x') = \left(1 + \frac{(x - x')^2}{2\alpha\ell^2} \right)^{-\alpha}
267275
"""
268276

269-
def __init__(self, input_dim, lengthscales, alpha, active_dims=None):
270-
super(RatQuad, self).__init__(input_dim, lengthscales, active_dims)
277+
def __init__(self, input_dim, alpha, ls, ls_inv, active_dims=None):
278+
super(RatQuad, self).__init__(input_dim, ls, ls_inv, active_dims)
271279
self.alpha = alpha
272280

273281
def full(self, X, Xs=None):

0 commit comments

Comments
 (0)