Commit 2f7a3cf

Merge branch 'master' into dGPFantasize
2 parents: f50a9f8 + ba1dcdd

File tree: 6 files changed (+307, -10 lines)

docs/requirements.txt

Lines changed: 10 additions & 8 deletions
@@ -1,12 +1,14 @@
-setuptools_scm
-nbformat
+setuptools_scm<=7.1.0
 ipython<=8.6.0
 ipykernel<=6.17.1
-sphinx
-sphinx_rtd_theme
-sphinx_autodoc_typehints
-nbsphinx
-m2r2
-pyro-ppl
 linear_operator>=0.4.0
+m2r2<=0.3.3.post2
+nbclient<=0.7.3
+nbformat<=5.8.0
+nbsphinx<=0.9.1
+platformdirs<=3.2.0
+pyro-ppl
+sphinx<=6.2.1
+sphinx_rtd_theme<0.5
+sphinx_autodoc_typehints<=1.23.0
 torch>=1.11

docs/source/kernels.rst

Lines changed: 5 additions & 0 deletions
@@ -147,6 +147,11 @@ Specialty Kernels
 .. autoclass:: ArcKernel
    :members:
 
+:hidden:`HammingIMQKernel`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: HammingIMQKernel
+   :members:
+
 :hidden:`IndexKernel`
 ~~~~~~~~~~~~~~~~~~~~~~
 
gpytorch/distributions/multivariate_normal.py

Lines changed: 2 additions & 2 deletions
@@ -215,7 +215,7 @@ def rsample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optiona
         :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
         :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
             i.i.d. (or approximately i.i.d.) standard Normal samples to
-            reparameterize. (Defualt: None.)
+            reparameterize. (Default: None.)
         :return: A `*sample_shape x *batch_shape x N` tensor of i.i.d. reparameterized samples.
         """
         covar = self.lazy_covariance_matrix
@@ -274,7 +274,7 @@ def sample(self, sample_shape: torch.Size = torch.Size(), base_samples: Optional
         :param sample_shape: The number of samples to generate. (Default: `torch.Size([])`.)
         :param base_samples: The `*sample_shape x *batch_shape x N` tensor of
             i.i.d. (or approximately i.i.d.) standard Normal samples to
-            reparameterize. (Defualt: None.)
+            reparameterize. (Default: None.)
         :return: A `*sample_shape x *batch_shape x N` tensor of i.i.d. samples.
         """
         with torch.no_grad():
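For context on the `base_samples` argument whose docstring is corrected here, a minimal sketch of the reparameterized sampling path (the distribution values and shapes below are illustrative, not taken from this commit):

import torch
import gpytorch

# A toy 3-dimensional multivariate normal.
mean = torch.zeros(3)
covar = torch.eye(3)
mvn = gpytorch.distributions.MultivariateNormal(mean, covar)

# Supplying base_samples fixes the i.i.d. standard Normal noise, so repeated
# rsample() calls reproduce the same reparameterized draws.
sample_shape = torch.Size([4])
base_samples = torch.randn(*sample_shape, 3)  # *sample_shape x N, as the docstring describes
samples = mvn.rsample(sample_shape, base_samples=base_samples)  # shape: 4 x 3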

gpytorch/kernels/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from .gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel
 from .grid_interpolation_kernel import GridInterpolationKernel
 from .grid_kernel import GridKernel
+from .hamming_kernel import HammingIMQKernel
 from .index_kernel import IndexKernel
 from .inducing_point_kernel import InducingPointKernel
 from .kernel import AdditiveKernel, Kernel, ProductKernel
@@ -43,6 +44,7 @@
     "GaussianSymmetrizedKLKernel",
     "GridKernel",
     "GridInterpolationKernel",
+    "HammingIMQKernel",
     "IndexKernel",
     "InducingPointKernel",
     "LCMKernel",

gpytorch/kernels/hamming_kernel.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
from typing import Optional

import torch
from torch import nn, Tensor

from gpytorch.constraints.constraints import Interval, Positive
from gpytorch.kernels.kernel import Kernel
from gpytorch.priors.prior import Prior


EMPTY_SIZE = torch.Size([])


class HammingIMQKernel(Kernel):
    r"""
    Computes a covariance matrix based on the inverse multiquadratic Hamming kernel
    between inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`:

    .. math::
        \begin{equation*}
            k_{\text{H-IMQ}}(\mathbf{x_1}, \mathbf{x_2}) =
            \left( \frac{1 + \alpha}{\alpha + d_{\text{Hamming}}(x1, x2)} \right)^\beta
        \end{equation*}

    where :math:`\alpha` and :math:`\beta` are strictly positive scale parameters.
    This kernel was proposed in `Biological Sequence Kernels with Guaranteed Flexibility`.
    See http://arxiv.org/abs/2304.03775 for more details.

    This kernel is meant to be used for fixed-length one-hot encoded discrete sequences.
    Because GPyTorch is particular about dimensions, the one-hot sequence encoding should be flattened
    to a vector with length :math:`T \times V`, where :math:`T` is the sequence length and :math:`V` is the
    vocabulary size.

    :param vocab_size: The size of the vocabulary.
    :param batch_shape: Set this if you want separate kernel hyperparameters for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf{x_1}` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param alpha_prior: Set this if you want to apply a prior to the
        alpha parameter.
    :param alpha_constraint: Set this if you want to apply a constraint
        to the alpha parameter. If None is passed, the default is `Positive()`.
    :param beta_prior: Set this if you want to apply a prior to the
        beta parameter.
    :param beta_constraint: Set this if you want to apply a constraint
        to the beta parameter. If None is passed, the default is `Positive()`.

    Example:
        >>> vocab_size = 8
        >>> x_cat = torch.tensor([[7, 7, 7, 7], [5, 7, 3, 4]])  # batch_size x seq_length
        >>> x_one_hot = F.one_hot(x_cat, num_classes=vocab_size)  # batch_size x seq_length x vocab_size
        >>> x_flat = x_one_hot.view(*x_cat.shape[:-1], -1)  # batch_size x (seq_length * vocab_size)
        >>> covar_module = gpytorch.kernels.HammingIMQKernel(vocab_size=vocab_size)
        >>> covar = covar_module(x_flat)  # Output: LinearOperator of size (2 x 2)
    """

    def __init__(
        self,
        vocab_size: int,
        batch_shape: torch.Size = EMPTY_SIZE,
        alpha_prior: Optional[Prior] = None,
        alpha_constraint: Optional[Interval] = None,
        beta_prior: Optional[Prior] = None,
        beta_constraint: Optional[Interval] = None,
    ):
        super().__init__(batch_shape=batch_shape)
        self.vocab_size = vocab_size

        # add alpha (scale) parameter
        alpha_constraint = Positive() if alpha_constraint is None else alpha_constraint
        self.register_parameter(
            name="raw_alpha",
            parameter=nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        if alpha_prior is not None:
            self.register_prior("alpha_prior", alpha_prior, self._alpha_param, self._alpha_closure)
        self.register_constraint("raw_alpha", alpha_constraint)

        # add beta parameter
        beta_constraint = Positive() if beta_constraint is None else beta_constraint
        self.register_parameter(
            name="raw_beta",
            parameter=nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        if beta_prior is not None:
            self.register_prior("beta_prior", beta_prior, self._beta_param, self._beta_closure)
        self.register_constraint("raw_beta", beta_constraint)

    @property
    def alpha(self) -> Tensor:
        return self.raw_alpha_constraint.transform(self.raw_alpha)

    @alpha.setter
    def alpha(self, value: Tensor):
        self._set_alpha(value)

    def _alpha_param(self, m: Kernel) -> Tensor:
        # Used by the alpha_prior
        return m.alpha

    def _alpha_closure(self, m: Kernel, v: Tensor) -> Tensor:
        # Used by the alpha_prior
        return m._set_alpha(v)

    def _set_alpha(self, value: Tensor):
        # Used by the alpha_prior
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_alpha)
        self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

    @property
    def beta(self) -> Tensor:
        return self.raw_beta_constraint.transform(self.raw_beta)

    @beta.setter
    def beta(self, value: Tensor):
        self._set_beta(value)

    def _beta_param(self, m: Kernel) -> Tensor:
        # Used by the beta_prior
        return m.beta

    def _beta_closure(self, m: Kernel, v: Tensor) -> Tensor:
        # Used by the beta_prior
        return m._set_beta(v)

    def _set_beta(self, value: Tensor):
        # Used by the beta_prior
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_beta)
        self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

    def _imq(self, dist: Tensor) -> Tensor:
        return ((1 + self.alpha) / (self.alpha + dist)).pow(self.beta)

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params):
        # GPyTorch is pretty particular about dimensions, so we need to unflatten the one-hot encoding
        x1 = x1.view(*x1.shape[:-1], -1, self.vocab_size)
        x2 = x2.view(*x2.shape[:-1], -1, self.vocab_size)

        x1_eq_x2 = torch.equal(x1, x2)

        if diag:
            if x1_eq_x2:
                # Identical inputs have Hamming distance 0, so the diagonal is ((1 + alpha) / alpha) ** beta
                res = ((1 + self.alpha) / self.alpha).pow(self.beta)
                skip_dims = [-1] * len(self.batch_shape)
                return res.expand(*skip_dims, x1.size(-3))
            else:
                dist = x1.size(-2) - (x1 * x2).sum(dim=(-1, -2))
                return self._imq(dist)
        else:
            dist = hamming_dist(x1, x2, x1_eq_x2)
            return self._imq(dist)


def hamming_dist(x1: Tensor, x2: Tensor, x1_eq_x2: bool) -> Tensor:
    # For one-hot rows, the elementwise product sums to the number of matching positions,
    # so the Hamming distance is the sequence length minus that sum.
    res = x1.size(-2) - (x1.unsqueeze(-3) * x2.unsqueeze(-4)).sum(dim=(-1, -2))
    if x1_eq_x2 and not x1.requires_grad and not x2.requires_grad:
        res.diagonal(dim1=-2, dim2=-1).fill_(0)
    # Zero out negative values
    return res.clamp_min_(0)

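The helper `hamming_dist` relies on one-hot encodings: the elementwise product of two one-hot rows is 1 exactly at matching positions, so the distance is the sequence length minus the summed product. A tiny standalone check of that identity (sequence values chosen arbitrarily):

import torch
import torch.nn.functional as F

vocab_size = 8
a = torch.tensor([7, 7, 7, 7])
b = torch.tensor([7, 5, 7, 4])

direct = (a != b).sum()  # count mismatching positions directly

# Via one-hot encodings: each matching position contributes 1 to the elementwise product.
a_oh = F.one_hot(a, num_classes=vocab_size).float()
b_oh = F.one_hot(b, num_classes=vocab_size).float()
via_one_hot = a.numel() - (a_oh * b_oh).sum()

assert direct == via_one_hot  # both equal 2 for these sequences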
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
#!/usr/bin/env python3

import pickle
import unittest

import torch

from gpytorch.kernels import HammingIMQKernel
from gpytorch.priors import GammaPrior


class TestHammingIMQKernel(unittest.TestCase):
    def create_seq(self, batch_size, seq_len, vocab_size):
        return torch.randint(0, vocab_size, (batch_size, seq_len))

    def create_seq_pairs(self, batch_size, seq_len, vocab_size):
        a = self.create_seq(batch_size, seq_len, vocab_size)
        b = self.create_seq(batch_size, seq_len, vocab_size)
        set_to_a = torch.rand(batch_size, seq_len) < 0.5
        b = torch.where(set_to_a, a, b)
        return a, b

    def test_computes_hamming_imq_function(self):
        r"""
        Create one-hot encoded discrete sequences and flatten them.
        Compute the pairwise Hamming distance $d$ between the sequences.
        Check that the result of the kernel evaluation is
        $((1 + \alpha) / (\alpha + d))^{\beta}$.
        """
        vocab_size = 8
        seq_len = 4
        alpha = 2.0
        beta = 0.5
        kernel = HammingIMQKernel(vocab_size=vocab_size)
        kernel.initialize(alpha=alpha, beta=beta)
        kernel.eval()

        # Create two discrete sequences with some matches.
        a = torch.tensor([[7, 7, 7, 7], [5, 7, 3, 4]])
        b = torch.tensor(
            [
                [7, 5, 7, 4],
                [6, 7, 3, 7],
                [5, 7, 3, 4],
            ]
        )

        # Convert to one-hot representation.
        a_one_hot = torch.zeros(*a.shape, vocab_size)
        a_one_hot.scatter_(index=a.unsqueeze(-1), dim=-1, value=1)
        b_one_hot = torch.zeros(*b.shape, vocab_size)
        b_one_hot.scatter_(index=b.unsqueeze(-1), dim=-1, value=1)

        # Flatten the one-hot representations.
        a_one_hot_flat = a_one_hot.view(a.size(0), -1)
        b_one_hot_flat = b_one_hot.view(b.size(0), -1)

        # Compute the Hamming distance.
        d = seq_len - (a_one_hot.unsqueeze(-3) * b_one_hot.unsqueeze(-4)).sum(dim=(-1, -2))

        # Compute the kernel evaluation.
        actual = ((1 + alpha) / (alpha + d)) ** beta
        res = kernel(a_one_hot_flat, b_one_hot_flat).to_dense()

        # Check the result.
        self.assertLess(torch.norm(res - actual), 1e-5)

    def test_initialize_alpha(self):
        """
        Check that the kernel can be initialized with alpha.
        """
        alpha = 2.0
        kernel = HammingIMQKernel(vocab_size=8)
        kernel.initialize(alpha=alpha)
        actual_value = torch.tensor(alpha).view_as(kernel.alpha)
        self.assertLess(torch.norm(kernel.alpha - actual_value), 1e-5)

    def test_initialize_alpha_batch(self):
        batch_size = 2
        alpha = torch.rand(batch_size)
        kernel = HammingIMQKernel(vocab_size=8, batch_shape=torch.Size([batch_size]))
        kernel.initialize(alpha=alpha)
        actual_value = alpha.view_as(kernel.alpha)
        self.assertLess(torch.norm(kernel.alpha - actual_value), 1e-5)

    def test_initialize_beta(self):
        """
        Check that the kernel can be initialized with beta.
        """
        beta = 0.5
        kernel = HammingIMQKernel(vocab_size=8)
        kernel.initialize(beta=beta)
        actual_value = torch.tensor(beta).view_as(kernel.beta)
        self.assertLess(torch.norm(kernel.beta - actual_value), 1e-5)

    def test_initialize_beta_batch(self):
        batch_size = 2
        beta = torch.rand(batch_size)
        kernel = HammingIMQKernel(vocab_size=8, batch_shape=torch.Size([batch_size]))
        kernel.initialize(beta=beta)
        actual_value = beta.view_as(kernel.beta)
        self.assertLess(torch.norm(kernel.beta - actual_value), 1e-5)

    def create_kernel_with_prior(self, alpha_prior=None, beta_prior=None):
        return HammingIMQKernel(
            vocab_size=8,
            alpha_prior=alpha_prior,
            beta_prior=beta_prior,
        )

    def test_prior_type(self):
        self.create_kernel_with_prior()
        self.create_kernel_with_prior(
            alpha_prior=GammaPrior(1.0, 1.0),
            beta_prior=GammaPrior(1.0, 1.0),
        )
        self.assertRaises(TypeError, self.create_kernel_with_prior, 1)

    def test_pickle_with_prior(self):
        kernel = self.create_kernel_with_prior(
            alpha_prior=GammaPrior(1.0, 1.0),
            beta_prior=GammaPrior(1.0, 1.0),
        )
        pickle.loads(pickle.dumps(kernel))  # Should be able to pickle and unpickle with a prior.


if __name__ == "__main__":
    unittest.main()
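To round out the new kernel and its tests, here is a short usage sketch (not part of the commit): integer sequences are one-hot encoded with `F.one_hot`, flattened as the class docstring requires, and passed to the kernel. The variable names are illustrative.

import torch
import torch.nn.functional as F
import gpytorch

vocab_size, seq_len = 8, 4
seqs = torch.randint(0, vocab_size, (5, seq_len))           # 5 integer-coded sequences
one_hot = F.one_hot(seqs, num_classes=vocab_size).float()   # 5 x seq_len x vocab_size
flat = one_hot.view(seqs.size(0), -1)                       # 5 x (seq_len * vocab_size)

kernel = gpytorch.kernels.HammingIMQKernel(vocab_size=vocab_size)
covar = kernel(flat).to_dense()                             # 5 x 5 covariance matrix
# Identical sequences have Hamming distance 0, so the diagonal equals ((1 + alpha) / alpha) ** beta.
print(covar.shape, covar.diagonal())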
