
Commit a6c5b02

RSid8 and gpleiss authored
Added Piecewise Polynomial Kernel (#1738)
Co-authored-by: Geoff Pleiss <[email protected]>
1 parent 1db1744 commit a6c5b02

6 files changed: +234 -6 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 
 # VS Code settings stuff
 .vscode
-
+.pylintrc
 # Project specific
 gpytorch/libfft
 .pytest_cache

docs/source/kernels.rst

Lines changed: 6 additions & 0 deletions

@@ -53,6 +53,12 @@ Standard Kernels
 .. autoclass:: PeriodicKernel
    :members:
 
+:hidden:`PiecewisePolynomialKernel`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: PiecewisePolynomialKernel
+   :members:
+
 :hidden:`PolynomialKernel`
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
gpytorch/kernels/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 from .multitask_kernel import MultitaskKernel
 from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel
 from .periodic_kernel import PeriodicKernel
+from .piecewise_polynomial_kernel import PiecewisePolynomialKernel
 from .polynomial_kernel import PolynomialKernel
 from .polynomial_kernel_grad import PolynomialKernelGrad
 from .product_structure_kernel import ProductStructureKernel
@@ -50,6 +51,7 @@
     "MultitaskKernel",
     "NewtonGirardAdditiveKernel",
     "PeriodicKernel",
+    "PiecewisePolynomialKernel",
     "PolynomialKernel",
     "PolynomialKernelGrad",
     "ProductKernel",
gpytorch/kernels/piecewise_polynomial_kernel.py

Lines changed: 118 additions & 0 deletions

import torch

from .kernel import Kernel


class PiecewisePolynomialKernel(Kernel):
    r"""
    Computes a covariance matrix based on the Piecewise Polynomial kernel
    between inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`:

    .. math::

        \begin{align}
            r &= \left\Vert \mathbf{x_1} - \mathbf{x_2} \right\Vert \\
            j &= \lfloor \frac{D}{2} \rfloor + q + 1 \\
            K_{\text{ppD, 0}}(\mathbf{x_1}, \mathbf{x_2}) &= (1-r)^j_+ , \\
            K_{\text{ppD, 1}}(\mathbf{x_1}, \mathbf{x_2}) &= (1-r)^{j+1}_+ ((j + 1)r + 1), \\
            K_{\text{ppD, 2}}(\mathbf{x_1}, \mathbf{x_2}) &= (1-r)^{j+2}_+ (1 + (j+2)r +
                \frac{j^2 + 4j + 3}{3}r^2), \\
            K_{\text{ppD, 3}}(\mathbf{x_1}, \mathbf{x_2}) &= (1-r)^{j+3}_+
                (1 + (j+3)r + \frac{6j^2 + 36j + 45}{15}r^2 +
                \frac{j^3 + 9j^2 + 23j + 15}{15}r^3) \\
        \end{align}

    where :math:`K_{\text{ppD, q}}` is positive semidefinite in :math:`\mathbb{R}^{D}` and
    :math:`q` is the smoothness coefficient. See `Rasmussen and Williams (2006)`_ Equation 4.21.

    .. note:: This kernel does not have an `outputscale` parameter. To add a scaling parameter,
        decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.

    :param q: (Default: 2) The smoothness parameter.
    :type q: int (0, 1, 2 or 3)
    :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
        input dimension. It should be `d` if :attr:`x1` is a `... x n x d` matrix.
    :type ard_num_dims: int, optional
    :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
        batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
    :type batch_shape: torch.Size, optional
    :param active_dims: (Default: `None`) Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the dimensions.
    :type active_dims: Tuple(int)
    :param lengthscale_prior: (Default: `None`) Set this if you want to apply a prior to the
        lengthscale parameter.
    :type lengthscale_prior: ~gpytorch.priors.Prior, optional
    :param lengthscale_constraint: (Default: `Positive`) Set this if you want to apply a
        constraint to the lengthscale parameter.
    :type lengthscale_constraint: ~gpytorch.constraints.Positive, optional
    :param eps: (Default: 1e-6) The minimum value that the lengthscale can take (prevents
        divide by zero errors).
    :type eps: float, optional

    :var torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
        :attr:`ard_num_dims` and :attr:`batch_shape` arguments.

    .. _Rasmussen and Williams (2006):
        http://www.gaussianprocess.org/gpml/

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch option
        >>> covar_module = gpytorch.kernels.ScaleKernel(
        ...     gpytorch.kernels.PiecewisePolynomialKernel(q=2))
        >>> # Non-batch: ARD (different lengthscale for each input dimension)
        >>> covar_module = gpytorch.kernels.ScaleKernel(
        ...     gpytorch.kernels.PiecewisePolynomialKernel(q=2, ard_num_dims=5)
        ... )
        >>> covar = covar_module(x)  # Output: LazyTensor of size (10 x 10)
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: different lengthscale for each batch
        >>> covar_module = gpytorch.kernels.ScaleKernel(
        ...     gpytorch.kernels.PiecewisePolynomialKernel(q=2, batch_shape=torch.Size([2]))
        ... )
        >>> covar = covar_module(batch_x)  # Output: LazyTensor of size (2 x 10 x 10)
    """

    has_lengthscale = True

    def __init__(self, q=2, **kwargs):
        super(PiecewisePolynomialKernel, self).__init__(**kwargs)
        if q not in {0, 1, 2, 3}:
            raise ValueError("q expected to be 0, 1, 2 or 3")
        self.q = q

    def fmax(self, r, j, q):
        # (1 - r)_+^{j + q}: the truncation that gives the kernel compact support
        return torch.max(torch.tensor(0.0), 1 - r).pow(j + q)

    def get_cov(self, r, j, q):
        # polynomial factor of K_{ppD, q}; see Rasmussen and Williams (2006), Eq. 4.21
        if q == 0:
            return 1
        if q == 1:
            return (j + 1) * r + 1
        if q == 2:
            return 1 + (j + 2) * r + ((j ** 2 + 4 * j + 3) / 3.0) * r ** 2
        if q == 3:
            return (
                1
                + (j + 3) * r
                + ((6 * j ** 2 + 36 * j + 45) / 15.0) * r ** 2
                + ((j ** 3 + 9 * j ** 2 + 23 * j + 15) / 15.0) * r ** 3
            )

    def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params):
        x1_ = x1.div(self.lengthscale)
        x2_ = x2.div(self.lengthscale)
        if last_dim_is_batch is True:
            D = x1.shape[1]
        else:
            D = x1.shape[-1]
        j = torch.floor(torch.tensor(D / 2.0)) + self.q + 1
        if last_dim_is_batch and diag:
            r = self.covar_dist(x1_, x2_, last_dim_is_batch=True, diag=True)
        elif diag:
            r = self.covar_dist(x1_, x2_, diag=True)
        elif last_dim_is_batch:
            r = self.covar_dist(x1_, x2_, last_dim_is_batch=True)
        else:
            r = self.covar_dist(x1_, x2_)
        cov_matrix = self.fmax(r, j, self.q) * self.get_cov(r, j, self.q)
        return cov_matrix
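
For context, here is a minimal end-to-end sketch of the new kernel inside a gpytorch regression model. The model class, data, and hyperparameters are illustrative assumptions, not part of this commit; only PiecewisePolynomialKernel and ScaleKernel come from the diff above:

    import torch
    import gpytorch

    class GPModel(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            # ScaleKernel adds the outputscale that PiecewisePolynomialKernel lacks
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.PiecewisePolynomialKernel(q=2)
            )

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(
                self.mean_module(x), self.covar_module(x)
            )

    train_x = torch.linspace(0, 1, 20)
    train_y = torch.sin(2 * 3.14159 * train_x) + 0.1 * torch.randn(20)
    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = GPModel(train_x, train_y, likelihood)

    # For D = 1 and q = 2 the exponent is j = floor(1/2) + 2 + 1 = 3, so any
    # pair of points whose lengthscale-scaled distance is >= 1 contributes an
    # exactly-zero covariance entry.
    model.eval()
    likelihood.eval()
    with torch.no_grad():
        pred = likelihood(model(torch.linspace(0, 1, 5)))

Because of that compact support, the covariance matrix becomes sparse once points sit more than one lengthscale apart, which is the main practical attraction of this kernel family.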
test/kernels/test_piecewise_polynomial_kernel.py

Lines changed: 103 additions & 0 deletions

#!/usr/bin/env python3

import math
import unittest

import torch

from gpytorch.kernels import PiecewisePolynomialKernel
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestPiecewisePolynomialKernel(unittest.TestCase, BaseKernelTestCase):
    def create_kernel_no_ard(self, **kwargs):
        return PiecewisePolynomialKernel(q=2, **kwargs)

    def test_computes_piecewise_polynomial_kernel(self):
        a = torch.tensor([[4, 1], [2, 2], [8, 0]], dtype=torch.float)
        b = torch.tensor([[0, 0], [2, 1], [1, 0]], dtype=torch.float)
        kernel = PiecewisePolynomialKernel(q=0)
        kernel.eval()

        def test_r(a, b):
            # the kernel computes distances on lengthscale-scaled inputs
            return torch.cdist(a.div(kernel.lengthscale), b.div(kernel.lengthscale))

        def test_get_cov(r, j, q):
            if q == 0:
                return 1
            if q == 1:
                return (j + 1) * r + 1
            if q == 2:
                return 1 + (j + 2) * r + ((j ** 2 + 4 * j + 3) / 3.0) * r ** 2
            if q == 3:
                return (
                    1
                    + (j + 3) * r
                    + ((6 * j ** 2 + 36 * j + 45) / 15.0) * r ** 2
                    + ((j ** 3 + 9 * j ** 2 + 23 * j + 15) / 15.0) * r ** 3
                )

        def test_fmax(r, j, q):
            return torch.max(torch.tensor(0.0), 1 - r).pow(j + q)

        # j = floor(D / 2) + q + 1, matching the kernel's forward pass
        j = math.floor(a.shape[-1] / 2.0) + kernel.q + 1
        r = test_r(a, b)
        actual = test_fmax(r, j, kernel.q) * test_get_cov(r, j, kernel.q)
        res = kernel(a, b).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # diag
        actual = actual.diag()
        res = kernel(a, b).diag()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # batch_dims
        actual = torch.zeros(2, 3, 3)
        for i in range(2):
            actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).evaluate()

        res = kernel(a, b, last_dim_is_batch=True).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)

        # batch_dims + diag
        res = kernel(a, b, last_dim_is_batch=True).diag()
        actual = torch.cat([actual[i].diag().unsqueeze(0) for i in range(actual.size(0))])
        self.assertLess(torch.norm(res - actual), 1e-5)

    def test_piecewise_polynomial_kernel_batch(self):
        a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1)
        b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1)
        kernel = PiecewisePolynomialKernel(q=0, batch_shape=torch.Size([2]))
        kernel.eval()

        def test_r(a, b):
            # the kernel computes distances on lengthscale-scaled inputs
            return torch.cdist(a.div(kernel.lengthscale), b.div(kernel.lengthscale))

        def test_get_cov(r, j, q):
            if q == 0:
                return 1
            if q == 1:
                return (j + 1) * r + 1
            if q == 2:
                return 1 + (j + 2) * r + ((j ** 2 + 4 * j + 3) / 3.0) * r ** 2
            if q == 3:
                return (
                    1
                    + (j + 3) * r
                    + ((6 * j ** 2 + 36 * j + 45) / 15.0) * r ** 2
                    + ((j ** 3 + 9 * j ** 2 + 23 * j + 15) / 15.0) * r ** 3
                )

        def test_fmax(r, j, q):
            return torch.max(torch.tensor(0.0), 1 - r).pow(j + q)

        # j = floor(D / 2) + q + 1, matching the kernel's forward pass
        j = math.floor(a.shape[-1] / 2.0) + kernel.q + 1
        r = test_r(a, b)
        actual = test_fmax(r, j, kernel.q) * test_get_cov(r, j, kernel.q)
        res = kernel(a, b).evaluate()
        self.assertLess(torch.norm(res - actual), 1e-5)


if __name__ == "__main__":
    unittest.main()
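
A side note on the reference values in these tests: for the chosen inputs, every pairwise distance is at least 1 even before division by the default lengthscale (softplus-initialized, roughly 0.693, so scaling only increases the distances), which means the (1 - r)_+ factor truncates every entry to zero and res and actual are compared as all-zero tensors. A small check of that compact-support behavior (a sketch, not part of the commit):

    import torch
    from gpytorch.kernels import PiecewisePolynomialKernel

    kernel = PiecewisePolynomialKernel(q=0)
    a = torch.tensor([[4.0, 1.0], [2.0, 2.0], [8.0, 0.0]])
    b = torch.tensor([[0.0, 0.0], [2.0, 1.0], [1.0, 0.0]])
    # every pair is at least distance 1 apart, so (1 - r)_+ == 0 everywhere
    print(kernel(a, b).evaluate())  # tensor of zeros, shape (3, 3)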

test/kernels/test_polynomial_kernel.py

Lines changed: 4 additions & 5 deletions

@@ -25,16 +25,15 @@ def test_computes_quadratic_kernel(self):
 
         res = kernel(a, b).evaluate()
         self.assertLess(torch.norm(res - actual), 1e-5)
-
         # diag
         res = kernel(a, b).diag()
         actual = actual.diag()
         self.assertLess(torch.norm(res - actual), 1e-5)
 
         # batch_dims
         actual = torch.zeros(2, 3, 3)
-        for l in range(2):
-            actual[l] = kernel(a[:, l].unsqueeze(-1), b[:, l].unsqueeze(-1)).evaluate()
+        for i in range(2):
+            actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).evaluate()
 
         res = kernel(a, b, last_dim_is_batch=True).evaluate()
         self.assertLess(torch.norm(res - actual), 1e-5)
@@ -65,8 +64,8 @@ def test_computes_cubic_kernel(self):
 
         # batch_dims
         actual = torch.zeros(2, 3, 3)
-        for l in range(2):
-            actual[l] = kernel(a[:, l].unsqueeze(-1), b[:, l].unsqueeze(-1)).evaluate()
+        for i in range(2):
+            actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).evaluate()
 
         res = kernel(a, b, last_dim_is_batch=True).evaluate()
         self.assertLess(torch.norm(res - actual), 1e-5)
