
Commit 8326781

Rework GridKernel and GridInterpolationKernel to not use last_dim_is_batch
1 parent 70af591 commit 8326781

File tree

4 files changed: +174 −192 lines

gpytorch/kernels/grid_interpolation_kernel.py

Lines changed: 33 additions & 39 deletions
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3
 
-from typing import List, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
+from jaxtyping import Float
 from linear_operator import to_linear_operator
-from linear_operator.operators import InterpolatedLinearOperator
+from linear_operator.operators import InterpolatedLinearOperator, LinearOperator
+from torch import Tensor
 
 from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy
 from ..utils.grid import create_grid
@@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel):
     .. math::
 
        \begin{equation*}
-          k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}}
+          k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z} \mathbf{w_{x_2}}
        \end{equation*}
 
     where
 
-    * :math:`U` is the set of gridded inducing points
+    * :math:`\boldsymbol Z` is the set of gridded inducing points
 
-    * :math:`K_{U,U}` is the kernel matrix between the inducing points
+    * :math:`K_{\boldsymbol Z, \boldsymbol Z}` is the kernel matrix between the inducing points
 
     * :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
       :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
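For intuition, the SKI product in the docstring above is an ordinary matrix sandwich once the interpolation weights are materialized densely. A minimal sketch under that assumption (the names `w_x1`, `w_x2`, and `k_zz` are illustrative; the actual code keeps the weights sparse):

    import torch

    def ski_covariance(w_x1, w_x2, k_zz):
        # w_x1: (N1, M) interpolation weights of x1 onto the M grid points Z
        # w_x2: (N2, M) interpolation weights of x2
        # k_zz: (M, M) kernel matrix evaluated on the grid Z
        return w_x1 @ k_zz @ w_x2.transpose(-1, -2)  # (N1, N2)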
@@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel):
     `GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
     Periodic, Spectral Mixture, etc.)
 
-    Args:
-        base_kernel (Kernel):
-            The kernel to approximate with KISS-GP
-        grid_size (Union[int, List[int]]):
-            The size of the grid in each dimension.
-            If a single int is provided, then every dimension will have the same grid size.
-        num_dims (int):
-            The dimension of the input data. Required if `grid_bounds=None`
-        grid_bounds (tuple(float, float), optional):
-            The bounds of the grid, if known (high performance mode).
-            The length of the tuple must match the number of dimensions.
-            The entries represent the min/max values for each dimension.
-        active_dims (tuple of ints, optional):
-            Passed down to the `base_kernel`.
+    :param base_kernel: The kernel to approximate with KISS-GP.
+    :param grid_size: The size of the grid in each dimension.
+        If a single int is provided, then every dimension will have the same grid size.
+    :param num_dims: The dimension of the input data. Required if `grid_bounds=None`
+    :param grid_bounds: The bounds of the grid, if known (high performance mode).
+        The length of the tuple must match the number of dimensions.
+        The entries represent the min/max values for each dimension.
 
     .. _Kernel Interpolation for Scalable Structured Gaussian Processes:
         http://proceedings.mlr.press/v37/wilson15.pdf
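A typical construction matching the documented parameters (a hedged sketch; the grid size and dimensionality are arbitrary choices):

    import gpytorch

    # Approximate a scaled RBF kernel with KISS-GP on a 128-point grid over 1-D inputs.
    covar_module = gpytorch.kernels.GridInterpolationKernel(
        gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()),
        grid_size=128,
        num_dims=1,
    )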
@@ -72,10 +67,10 @@ class GridInterpolationKernel(GridKernel):
     def __init__(
         self,
         base_kernel: Kernel,
-        grid_size: Union[int, List[int]],
+        grid_size: Union[int, Iterable[int]],
         num_dims: Optional[int] = None,
         grid_bounds: Optional[Tuple[float, float]] = None,
-        active_dims: Optional[Tuple[int, ...]] = None,
+        **kwargs,
     ):
         has_initialized_grid = 0
         grid_is_dynamic = True
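With this change, `active_dims` is no longer a named parameter; it travels through `**kwargs` to the parent `GridKernel` constructor. A minimal sketch of the forwarding pattern (the class names here are generic stand-ins, not the real hierarchy):

    class Parent:
        def __init__(self, active_dims=None):
            self.active_dims = active_dims

    class Child(Parent):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)  # active_dims (and friends) pass straight through

    assert Child(active_dims=(0, 1)).active_dims == (0, 1)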
@@ -116,8 +111,7 @@ def __init__(
         super(GridInterpolationKernel, self).__init__(
             base_kernel=base_kernel,
             grid=grid,
-            interpolation_mode=True,
-            active_dims=active_dims,
+            **kwargs,
         )
         self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool))
 
@@ -129,23 +123,26 @@ def _tight_grid_bounds(self):
             for bound, spacing in zip(self.grid_bounds, grid_spacings)
         )
 
-    def _compute_grid(self, inputs, last_dim_is_batch=False):
-        n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
-        if last_dim_is_batch:
-            inputs = inputs.transpose(-1, -2).unsqueeze(-1)
-            n_dimensions = 1
-        batch_shape = inputs.shape[:-2]
-
+    def _compute_grid(self, inputs):
+        *batch_shape, n_data, n_dimensions = inputs.shape
         inputs = inputs.reshape(-1, n_dimensions)
         interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
         interp_indices = interp_indices.view(*batch_shape, n_data, -1)
         interp_values = interp_values.view(*batch_shape, n_data, -1)
         return interp_indices, interp_values
 
-    def _inducing_forward(self, last_dim_is_batch, **params):
-        return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params)
+    def _create_or_update_full_grid(self, grid: Iterable[Tensor]):
+        pass
+
+    def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool:
+        return True
 
-    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
+    def _inducing_forward(self, **params):
+        return super().forward(None, None, **params)
+
+    def forward(
+        self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... N_2 D"], diag: bool = False, **params
+    ) -> Float[Union[Tensor, LinearOperator], "... N_1 N_2"]:
         # See if we need to update the grid or not
         if self.grid_is_dynamic:  # This is true if a grid_bounds wasn't passed in
             if torch.equal(x1, x2):
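The reworked `_compute_grid` relies on extended unpacking so that any number of leading batch dimensions is handled uniformly, which is what makes the old `last_dim_is_batch` branch unnecessary. A standalone illustration of the reshape round-trip (the shapes are arbitrary):

    import torch

    x = torch.randn(3, 2, 10, 4)  # batch shape (3, 2), N=10 points, D=4 dimensions
    *batch_shape, n_data, n_dimensions = x.shape
    flat = x.reshape(-1, n_dimensions)        # (3 * 2 * 10, 4): one row per data point
    restored = flat.view(*batch_shape, n_data, -1)
    assert torch.equal(x, restored)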
@@ -180,16 +177,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
             )
             self.update_grid(grid)
 
-        base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params))
-        if last_dim_is_batch and base_lazy_tsr.size(-3) == 1:
-            base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1)
-
-        left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch)
+        base_lazy_tsr = to_linear_operator(self._inducing_forward(**params))
+        left_interp_indices, left_interp_values = self._compute_grid(x1)
         if torch.equal(x1, x2):
             right_interp_indices = left_interp_indices
             right_interp_values = left_interp_values
         else:
-            right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch)
+            right_interp_indices, right_interp_values = self._compute_grid(x2)
 
         batch_shape = torch.broadcast_shapes(
             base_lazy_tsr.batch_shape,
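Densely written out, the quantities assembled here combine as below (a hedged toy for intuition only; the real `forward` wraps them in a lazy `InterpolatedLinearOperator` rather than materializing anything):

    import torch

    def dense_ski_forward(k_zz, left_idx, left_val, right_idx, right_val):
        # k_zz: (M, M) grid kernel; *_idx (int64) and *_val: (N, q) per-point
        # grid indices and cubic interpolation weights, as in _compute_grid.
        m = k_zz.size(-1)
        w_left = torch.zeros(left_idx.size(0), m).scatter_add_(1, left_idx, left_val)
        w_right = torch.zeros(right_idx.size(0), m).scatter_add_(1, right_idx, right_val)
        return w_left @ k_zz @ w_right.transpose(-1, -2)  # (N_left, N_right)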
