Skip to content

Commit c900d34

Browse files
committed
Merge branch 'gp-module' of github.com:bwengals/pymc3 into gp-module
2 parents 92581a9 + 9d008c5 commit c900d34

File tree

7 files changed

+711
-113
lines changed

7 files changed

+711
-113
lines changed

docs/source/notebooks/GP-Latent.ipynb

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
"\\end{aligned}\n",
5656
"$$\n",
5757
"\n",
58-
"For more information about this reparameterization, see the secion on [drawing values from a multivariate distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Drawing_values_from_the_distribution). This reparameterization can be disabled by setting the optional flag in the `prior` method, `reparameterize = False`. The default is `True`."
58+
"For more information about this reparameterization, see the section on [drawing values from a multivariate distribution](https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Drawing_values_from_the_distribution). This reparameterization can be disabled by setting the optional flag in the `prior` method, `reparameterize = False`. The default is `True`."
5959
]
6060
},
6161
{
@@ -64,15 +64,15 @@
6464
"source": [
6565
"## `.conditional`\n",
6666
"\n",
67-
"The conditional method implements the \"predictive\" distribution for function values that were not necessarily part of the data set. This distribution is,\n",
67+
"The conditional method implements the \"predictive\" distribution for function values that were not part of the original data set. This distribution is,\n",
6868
"\n",
6969
"$$\n",
7070
"\\mathbf{f}_* \\mid \\mathbf{f} \\sim \\text{MvNormal} \\left(\n",
7171
" \\mathbf{m}_* + \\mathbf{K}_{*x}\\mathbf{K}_{xx}^{-1} \\mathbf{f} ,\\,\n",
7272
" \\mathbf{K}_{**} - \\mathbf{K}_{*x}\\mathbf{K}_{xx}^{-1}\\mathbf{K}_{x*} \\right)\n",
7373
"$$\n",
7474
"\n",
75-
"In PyMC3, using the same `gp` that we defined above, this is specified as,\n",
75+
"Using the same `gp` object we defined above, this is specified as,\n",
7676
"\n",
7777
"```python\n",
7878
"# vector of new X points we want to predict the function at\n",
@@ -85,7 +85,8 @@
8585
"If `gp` is part of a sum of GP objects, it can be conditioned on different components of that sum using the optional keyword argument `given`,\n",
8686
"\n",
8787
"```python\n",
88-
" f_star_diff = gp.conditional(\"f_star_diff\", n_points=100, X_star, given=a_different_gp)\n",
88+
" f_star_diff = gp.conditional(\"f_star_diff\", n_points=100, X_star, \n",
89+
" gp=a_different_gp)\n",
8990
"```"
9091
]
9192
},
@@ -125,7 +126,8 @@
125126
"ExecuteTime": {
126127
"end_time": "2017-08-14T20:08:36.301608Z",
127128
"start_time": "2017-08-14T20:08:33.664334Z"
128-
}
129+
},
130+
"collapsed": false
129131
},
130132
"outputs": [
131133
{
@@ -190,7 +192,8 @@
190192
"ExecuteTime": {
191193
"end_time": "2017-08-14T20:13:07.436337Z",
192194
"start_time": "2017-08-14T20:08:36.443723Z"
193-
}
195+
},
196+
"collapsed": false
194197
},
195198
"outputs": [
196199
{
@@ -239,7 +242,8 @@
239242
"ExecuteTime": {
240243
"end_time": "2017-08-14T20:13:08.808843Z",
241244
"start_time": "2017-08-14T20:13:07.453184Z"
242-
}
245+
},
246+
"collapsed": false
243247
},
244248
"outputs": [
245249
{
@@ -264,7 +268,8 @@
264268
"ExecuteTime": {
265269
"end_time": "2017-08-14T20:13:09.647392Z",
266270
"start_time": "2017-08-14T20:13:08.810335Z"
267-
}
271+
},
272+
"collapsed": false
268273
},
269274
"outputs": [
270275
{
@@ -313,7 +318,8 @@
313318
"ExecuteTime": {
314319
"end_time": "2017-08-14T20:13:25.150070Z",
315320
"start_time": "2017-08-14T20:13:09.648849Z"
316-
}
321+
},
322+
"collapsed": false
317323
},
318324
"outputs": [
319325
{
@@ -345,7 +351,8 @@
345351
"ExecuteTime": {
346352
"end_time": "2017-08-14T20:13:26.148611Z",
347353
"start_time": "2017-08-14T20:13:25.177747Z"
348-
}
354+
},
355+
"collapsed": false
349356
},
350357
"outputs": [
351358
{
@@ -425,7 +432,8 @@
425432
"ExecuteTime": {
426433
"end_time": "2017-08-14T20:15:45.834572Z",
427434
"start_time": "2017-08-14T20:15:45.499711Z"
428-
}
435+
},
436+
"collapsed": false
429437
},
430438
"outputs": [
431439
{
@@ -453,7 +461,8 @@
453461
"ExecuteTime": {
454462
"end_time": "2017-08-14T20:24:59.612463Z",
455463
"start_time": "2017-08-14T20:15:45.841669Z"
456-
}
464+
},
465+
"collapsed": false
457466
},
458467
"outputs": [
459468
{
@@ -496,7 +505,8 @@
496505
"ExecuteTime": {
497506
"end_time": "2017-08-14T20:25:07.307076Z",
498507
"start_time": "2017-08-14T20:24:59.619823Z"
499-
}
508+
},
509+
"collapsed": false
500510
},
501511
"outputs": [
502512
{
@@ -521,7 +531,8 @@
521531
"ExecuteTime": {
522532
"end_time": "2017-08-14T20:25:29.473243Z",
523533
"start_time": "2017-08-14T20:25:07.308607Z"
524-
}
534+
},
535+
"collapsed": false
525536
},
526537
"outputs": [
527538
{
@@ -550,7 +561,8 @@
550561
"ExecuteTime": {
551562
"end_time": "2017-08-14T20:25:30.710482Z",
552563
"start_time": "2017-08-14T20:25:29.478224Z"
553-
}
564+
},
565+
"collapsed": false
554566
},
555567
"outputs": [
556568
{
@@ -592,7 +604,7 @@
592604
"name": "python",
593605
"nbconvert_exporter": "python",
594606
"pygments_lexer": "ipython3",
595-
"version": "3.5.3"
607+
"version": "3.5.2"
596608
}
597609
},
598610
"nbformat": 4,

docs/source/notebooks/GP-Marginal.ipynb

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,8 @@
156156
"ExecuteTime": {
157157
"end_time": "2017-08-14T18:32:50.803008Z",
158158
"start_time": "2017-08-14T18:32:48.214450Z"
159-
}
159+
},
160+
"collapsed": false
160161
},
161162
"outputs": [
162163
{
@@ -208,7 +209,8 @@
208209
"ExecuteTime": {
209210
"end_time": "2017-08-14T18:32:55.433300Z",
210211
"start_time": "2017-08-14T18:32:50.805810Z"
211-
}
212+
},
213+
"collapsed": false
212214
},
213215
"outputs": [
214216
{
@@ -240,7 +242,8 @@
240242
"ExecuteTime": {
241243
"end_time": "2017-08-14T18:32:55.447681Z",
242244
"start_time": "2017-08-14T18:32:55.435521Z"
243-
}
245+
},
246+
"collapsed": false
244247
},
245248
"outputs": [
246249
{
@@ -333,7 +336,8 @@
333336
"ExecuteTime": {
334337
"end_time": "2017-08-14T18:34:37.785728Z",
335338
"start_time": "2017-08-14T18:32:55.449138Z"
336-
}
339+
},
340+
"collapsed": false
337341
},
338342
"outputs": [
339343
{
@@ -365,7 +369,8 @@
365369
"ExecuteTime": {
366370
"end_time": "2017-08-14T18:34:40.688955Z",
367371
"start_time": "2017-08-14T18:34:37.793988Z"
368-
}
372+
},
373+
"collapsed": false
369374
},
370375
"outputs": [
371376
{
@@ -412,7 +417,8 @@
412417
"ExecuteTime": {
413418
"end_time": "2017-08-14T18:36:16.040227Z",
414419
"start_time": "2017-08-14T18:34:40.690375Z"
415-
}
420+
},
421+
"collapsed": false
416422
},
417423
"outputs": [
418424
{
@@ -436,7 +442,8 @@
436442
"ExecuteTime": {
437443
"end_time": "2017-08-14T18:36:19.189319Z",
438444
"start_time": "2017-08-14T18:36:16.042082Z"
439-
}
445+
},
446+
"collapsed": false
440447
},
441448
"outputs": [
442449
{
@@ -496,7 +503,8 @@
496503
"ExecuteTime": {
497504
"end_time": "2017-08-14T18:36:20.178631Z",
498505
"start_time": "2017-08-14T18:36:19.190828Z"
499-
}
506+
},
507+
"collapsed": false
500508
},
501509
"outputs": [
502510
{
@@ -549,7 +557,7 @@
549557
"name": "python",
550558
"nbconvert_exporter": "python",
551559
"pygments_lexer": "ipython3",
552-
"version": "3.5.3"
560+
"version": "3.5.2"
553561
}
554562
},
555563
"nbformat": 4,

pymc3/gp/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from . import mean
33
from . import util
44
from .gp import Latent, Marginal, MarginalSparse, TP
5+
from .grid import Grid2DLatent

pymc3/gp/cov.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -195,22 +195,30 @@ class Stationary(Covariance):
195195
196196
Parameters
197197
----------
198-
lengthscales: If input_dim > 1, a list or array of scalars or PyMC3 random
198+
ls : If input_dim > 1, a list or array of scalars or PyMC3 random
199199
variables. If input_dim == 1, a scalar or PyMC3 random variable.
200+
ls_inv : 1 / ls. One of ls or ls_inv must be provided.
200201
"""
201202

202-
def __init__(self, input_dim, lengthscales, active_dims=None):
203+
def __init__(self, input_dim, ls=None, ls_inv=None, active_dims=None):
203204
super(Stationary, self).__init__(input_dim, active_dims)
204-
self.lengthscales = tt.as_tensor_variable(lengthscales)
205+
if (ls is None and ls_inv is None) or (ls is not None and ls_inv is not None):
206+
raise ValueError("Only one of 'ls' or 'ls_inv' must be provided")
207+
elif ls_inv is not None:
208+
if isinstance(ls_inv, (np.ndarray, list, tuple)):
209+
ls = 1.0 / np.asarray(ls_inv)
210+
else:
211+
ls = 1.0 / ls_inv
212+
self.ls = tt.as_tensor_variable(ls)
205213

206214
def square_dist(self, X, Xs):
207-
X = tt.mul(X, 1.0 / self.lengthscales)
215+
X = tt.mul(X, 1.0 / self.ls)
208216
X2 = tt.sum(tt.square(X), 1)
209217
if Xs is None:
210218
sqd = (-2.0 * tt.dot(X, tt.transpose(X))
211219
+ (tt.reshape(X2, (-1, 1)) + tt.reshape(X2, (1, -1))))
212220
else:
213-
Xs = tt.mul(Xs, 1.0 / self.lengthscales)
221+
Xs = tt.mul(Xs, 1.0 / self.ls)
214222
Xs2 = tt.sum(tt.square(Xs), 1)
215223
sqd = (-2.0 * tt.dot(X, tt.transpose(Xs))
216224
+ (tt.reshape(X2, (-1, 1)) + tt.reshape(Xs2, (1, -1))))
@@ -228,8 +236,8 @@ def full(self, X, Xs=None):
228236

229237

230238
class Periodic(Stationary):
231-
def __init__(self, input_dim, lengthscales, period, active_dims=None):
232-
super(Periodic, self).__init__(input_dim, lengthscales, active_dims)
239+
def __init__(self, input_dim, period, ls=None, ls_inv=None, active_dims=None):
240+
super(Periodic, self).__init__(input_dim, ls, ls_inv, active_dims)
233241
self.period = period
234242
def full(self, X, Xs=None):
235243
X, Xs = self._slice(X, Xs)
@@ -238,7 +246,7 @@ def full(self, X, Xs=None):
238246
f1 = X.dimshuffle(0, 'x', 1)
239247
f2 = Xs.dimshuffle('x', 0, 1)
240248
r = np.pi * (f1 - f2) / self.period
241-
r = tt.sum(tt.square(tt.sin(r) / self.lengthscales), 2)
249+
r = tt.sum(tt.square(tt.sin(r) / self.ls), 2)
242250
return tt.exp(-0.5 * r)
243251

244252

@@ -266,8 +274,8 @@ class RatQuad(Stationary):
266274
k(x, x') = \left(1 + \frac{(x - x')^2}{2\alpha\ell^2} \right)^{-\alpha}
267275
"""
268276

269-
def __init__(self, input_dim, lengthscales, alpha, active_dims=None):
270-
super(RatQuad, self).__init__(input_dim, lengthscales, active_dims)
277+
def __init__(self, input_dim, alpha, ls, ls_inv, active_dims=None):
278+
super(RatQuad, self).__init__(input_dim, ls, ls_inv, active_dims)
271279
self.alpha = alpha
272280

273281
def full(self, X, Xs=None):

0 commit comments

Comments
 (0)