from typing import Optional

import torch
from torch import nn, Tensor

from gpytorch.constraints.constraints import Interval, Positive
from gpytorch.kernels.kernel import Kernel
from gpytorch.priors.prior import Prior


EMPTY_SIZE = torch.Size([])


class HammingIMQKernel(Kernel):
    r"""
    Computes a covariance matrix based on the inverse multiquadratic (IMQ) Hamming kernel
    between inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`:

    .. math::
        \begin{equation*}
            k_{\text{H-IMQ}}(\mathbf{x_1}, \mathbf{x_2}) =
            \left( \frac{1 + \alpha}{\alpha + d_{\text{Hamming}}(\mathbf{x_1}, \mathbf{x_2})} \right)^\beta
        \end{equation*}

    where :math:`\alpha` and :math:`\beta` are strictly positive scale parameters.
    This kernel was proposed in `Biological Sequence Kernels with Guaranteed Flexibility`.
    See http://arxiv.org/abs/2304.03775 for more details.

    This kernel is meant to be used with fixed-length, one-hot encoded discrete sequences.
    Because GPyTorch is particular about dimensions, the one-hot sequence encoding should be
    flattened to a vector of length :math:`T \times V`, where :math:`T` is the sequence length
    and :math:`V` is the vocabulary size.

    :param vocab_size: The size of the vocabulary.
    :param batch_shape: Set this if you want separate kernel hyperparameters for each batch of input
        data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf{x_1}` is
        a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
    :param alpha_prior: Set this if you want to apply a prior to the
        alpha parameter.
    :param alpha_constraint: Set this if you want to apply a constraint
        to the alpha parameter. If None is passed, the default is `Positive()`.
    :param beta_prior: Set this if you want to apply a prior to the
        beta parameter.
    :param beta_constraint: Set this if you want to apply a constraint
        to the beta parameter. If None is passed, the default is `Positive()`.

    Example:
        >>> import torch.nn.functional as F
        >>> vocab_size = 8
        >>> x_cat = torch.tensor([[7, 7, 7, 7], [5, 7, 3, 4]])  # batch_size x seq_length
        >>> x_one_hot = F.one_hot(x_cat, num_classes=vocab_size)  # batch_size x seq_length x vocab_size
        >>> x_flat = x_one_hot.view(*x_cat.shape[:-1], -1)  # batch_size x (seq_length * vocab_size)
        >>> covar_module = gpytorch.kernels.HammingIMQKernel(vocab_size=vocab_size)
        >>> covar = covar_module(x_flat)  # Output: LinearOperator of size (2 x 2)
    """

    def __init__(
        self,
        vocab_size: int,
        batch_shape: torch.Size = EMPTY_SIZE,
        alpha_prior: Optional[Prior] = None,
        alpha_constraint: Optional[Interval] = None,
        beta_prior: Optional[Prior] = None,
        beta_constraint: Optional[Interval] = None,
    ):
        super().__init__(batch_shape=batch_shape)
        self.vocab_size = vocab_size
        # add alpha (scale) parameter
        alpha_constraint = Positive() if alpha_constraint is None else alpha_constraint
        self.register_parameter(
            name="raw_alpha",
            parameter=nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        if alpha_prior is not None:
            self.register_prior("alpha_prior", alpha_prior, self._alpha_param, self._alpha_closure)
        self.register_constraint("raw_alpha", alpha_constraint)

        # add beta parameter
        beta_constraint = Positive() if beta_constraint is None else beta_constraint
        self.register_parameter(
            name="raw_beta",
            parameter=nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        if beta_prior is not None:
            self.register_prior("beta_prior", beta_prior, self._beta_param, self._beta_closure)
        self.register_constraint("raw_beta", beta_constraint)

    @property
    def alpha(self) -> Tensor:
        return self.raw_alpha_constraint.transform(self.raw_alpha)

    @alpha.setter
    def alpha(self, value: Tensor):
        self._set_alpha(value)

    def _alpha_param(self, m: Kernel) -> Tensor:
        # Used by the alpha_prior
        return m.alpha

    def _alpha_closure(self, m: Kernel, v: Tensor) -> Tensor:
        # Used by the alpha_prior
        return m._set_alpha(v)

    def _set_alpha(self, value: Tensor):
        # Used by the alpha_prior
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_alpha)
        self.initialize(raw_alpha=self.raw_alpha_constraint.inverse_transform(value))

    @property
    def beta(self) -> Tensor:
        return self.raw_beta_constraint.transform(self.raw_beta)

    @beta.setter
    def beta(self, value: Tensor):
        self._set_beta(value)

    def _beta_param(self, m: Kernel) -> Tensor:
        # Used by the beta_prior
        return m.beta

    def _beta_closure(self, m: Kernel, v: Tensor) -> Tensor:
        # Used by the beta_prior
        return m._set_beta(v)

    def _set_beta(self, value: Tensor):
        # Used by the beta_prior
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_beta)
        self.initialize(raw_beta=self.raw_beta_constraint.inverse_transform(value))

    def _imq(self, dist: Tensor) -> Tensor:
        # Inverse multiquadratic transform of the Hamming distance: ((1 + alpha) / (alpha + dist)) ** beta.
        return ((1 + self.alpha) / (self.alpha + dist)).pow(self.beta)

    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params):
        # GPyTorch is particular about dimensions, so unflatten the one-hot encoding
        # back to `... x seq_length x vocab_size` before computing Hamming distances.
        x1 = x1.view(*x1.shape[:-1], -1, self.vocab_size)
        x2 = x2.view(*x2.shape[:-1], -1, self.vocab_size)

        x1_eq_x2 = torch.equal(x1, x2)

        if diag:
            if x1_eq_x2:
                # The Hamming distance of a sequence to itself is zero, so every
                # diagonal entry equals ((1 + alpha) / alpha) ** beta.
                res = ((1 + self.alpha) / self.alpha).pow(self.beta)
                skip_dims = [-1] * len(self.batch_shape)
                return res.expand(*skip_dims, x1.size(-3))
            else:
                # Elementwise Hamming distance: sequence length minus the number of matching positions.
                dist = x1.size(-2) - (x1 * x2).sum(dim=(-1, -2))
                return self._imq(dist)
        else:
            dist = hamming_dist(x1, x2, x1_eq_x2)
            return self._imq(dist)


def hamming_dist(x1: Tensor, x2: Tensor, x1_eq_x2: bool) -> Tensor:
    # The dot product of two one-hot rows is 1 where the tokens match, so the Hamming
    # distance is the sequence length minus the number of matching positions.
    # The unsqueezes broadcast to an `... x N1 x N2` matrix of pairwise distances.
    res = x1.size(-2) - (x1.unsqueeze(-3) * x2.unsqueeze(-4)).sum(dim=(-1, -2))
    if x1_eq_x2 and not x1.requires_grad and not x2.requires_grad:
        # Exact zeros on the diagonal when the two inputs are identical.
        res.diagonal(dim1=-2, dim2=-1).fill_(0)
    # Zero out negative values
    return res.clamp_min_(0)
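

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the kernel API), assuming the
    # flattened one-hot input layout described in the class docstring. It builds a few
    # random sequences, evaluates the kernel by calling `forward` directly to obtain a
    # dense Tensor, and checks that the diagonal equals ((1 + alpha) / alpha) ** beta,
    # i.e. the value at zero Hamming distance.
    import torch.nn.functional as F

    vocab_size, seq_length, num_seqs = 8, 4, 5
    x_cat = torch.randint(vocab_size, (num_seqs, seq_length))  # num_seqs x seq_length token ids
    x_one_hot = F.one_hot(x_cat, num_classes=vocab_size).float()  # num_seqs x seq_length x vocab_size
    x_flat = x_one_hot.view(num_seqs, -1)  # num_seqs x (seq_length * vocab_size)

    kernel = HammingIMQKernel(vocab_size=vocab_size)
    covar = kernel.forward(x_flat, x_flat)  # dense num_seqs x num_seqs covariance matrix

    expected_diag = ((1 + kernel.alpha) / kernel.alpha).pow(kernel.beta)
    print(torch.allclose(covar.diagonal(), expected_diag.expand(num_seqs)))  # expected: True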