11#!/usr/bin/env python3
22
3- from typing import List , Optional , Tuple , Union
3+ from typing import Iterable , Optional , Tuple , Union
44
55import torch
6+ from jaxtyping import Float
67from linear_operator import to_linear_operator
7- from linear_operator .operators import InterpolatedLinearOperator
8+ from linear_operator .operators import InterpolatedLinearOperator , LinearOperator
9+ from torch import Tensor
810
911from ..models .exact_prediction_strategies import InterpolatedPredictionStrategy
1012from ..utils .grid import create_grid
@@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel):
2527 .. math::
2628
2729 \begin{equation*}
28- k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U } \mathbf{w_{x_2}}
30+ k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z } \mathbf{w_{x_2}}
2931 \end{equation*}
3032
3133 where
3234
33- * :math:`U ` is the set of gridded inducing points
35+ * :math:`\boldsymbol Z ` is the set of gridded inducing points
3436
35- * :math:`K_{U,U }` is the kernel matrix between the inducing points
37+ * :math:`K_{\boldsymbol Z, \boldsymbol Z }` is the kernel matrix between the inducing points
3638
3739 * :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on
3840 :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation.
@@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel):
5052 `GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern,
5153 Periodic, Spectral Mixture, etc.)
5254
53- Args:
54- base_kernel (Kernel):
55- The kernel to approximate with KISS-GP
56- grid_size (Union[int, List[int]]):
57- The size of the grid in each dimension.
58- If a single int is provided, then every dimension will have the same grid size.
59- num_dims (int):
60- The dimension of the input data. Required if `grid_bounds=None`
61- grid_bounds (tuple(float, float), optional):
62- The bounds of the grid, if known (high performance mode).
63- The length of the tuple must match the number of dimensions.
64- The entries represent the min/max values for each dimension.
65- active_dims (tuple of ints, optional):
66- Passed down to the `base_kernel`.
55+ :param base_kernel: The kernel to approximate with KISS-GP.
56+ :param grid_size: The size of the grid in each dimension.
57+ If a single int is provided, then every dimension will have the same grid size.
58+ :param num_dims: The dimension of the input data. Required if `grid_bounds=None`.
59+ :param grid_bounds: The bounds of the grid, if known (high performance mode).
60+ The length of the tuple must match the number of dimensions.
61+ Each entry is the (min, max) pair of values for the corresponding dimension.
6762
6863 .. _Kernel Interpolation for Scalable Structured Gaussian Processes:
6964 http://proceedings.mlr.press/v37/wilson15.pdf
@@ -72,10 +67,10 @@ class GridInterpolationKernel(GridKernel):
7267 def __init__ (
7368 self ,
7469 base_kernel : Kernel ,
75- grid_size : Union [int , List [int ]],
70+ grid_size : Union [int , Iterable [int ]],
7671 num_dims : Optional [int ] = None ,
7772 grid_bounds : Optional [Tuple [float , float ]] = None ,
78- active_dims : Optional [ Tuple [ int , ...]] = None ,
73+ ** kwargs ,
7974 ):
8075 has_initialized_grid = 0
8176 grid_is_dynamic = True
@@ -116,8 +111,7 @@ def __init__(
116111 super (GridInterpolationKernel , self ).__init__ (
117112 base_kernel = base_kernel ,
118113 grid = grid ,
119- interpolation_mode = True ,
120- active_dims = active_dims ,
114+ ** kwargs ,
121115 )
122116 self .register_buffer ("has_initialized_grid" , torch .tensor (has_initialized_grid , dtype = torch .bool ))
123117
@@ -129,23 +123,26 @@ def _tight_grid_bounds(self):
129123 for bound , spacing in zip (self .grid_bounds , grid_spacings )
130124 )
131125
132- def _compute_grid (self , inputs , last_dim_is_batch = False ):
133- n_data , n_dimensions = inputs .size (- 2 ), inputs .size (- 1 )
134- if last_dim_is_batch :
135- inputs = inputs .transpose (- 1 , - 2 ).unsqueeze (- 1 )
136- n_dimensions = 1
137- batch_shape = inputs .shape [:- 2 ]
138-
126+ def _compute_grid (self , inputs ):
127+ * batch_shape , n_data , n_dimensions = inputs .shape
139128 inputs = inputs .reshape (- 1 , n_dimensions )
140129 interp_indices , interp_values = Interpolation ().interpolate (self .grid , inputs )
141130 interp_indices = interp_indices .view (* batch_shape , n_data , - 1 )
142131 interp_values = interp_values .view (* batch_shape , n_data , - 1 )
143132 return interp_indices , interp_values
144133
145- def _inducing_forward (self , last_dim_is_batch , ** params ):
146- return super ().forward (self .grid , self .grid , last_dim_is_batch = last_dim_is_batch , ** params )
134+ def _create_or_update_full_grid (self , grid : Iterable [Tensor ]):
135+ pass
136+
137+ def _validate_inputs (self , x : Float [Tensor , "... N D" ]) -> bool :
138+ return True
147139
148- def forward (self , x1 , x2 , diag = False , last_dim_is_batch = False , ** params ):
140+ def _inducing_forward (self , ** params ):
141+ return super ().forward (None , None , ** params )
142+
143+ def forward (
144+ self , x1 : Float [Tensor , "... N_1 D" ], x2 : Float [Tensor , "... N_2 D" ], diag : bool = False , ** params
145+ ) -> Float [Union [Tensor , LinearOperator ], "... N_1 N_2" ]:
149146 # See if we need to update the grid or not
150147 if self .grid_is_dynamic : # This is true if a grid_bounds wasn't passed in
151148 if torch .equal (x1 , x2 ):
@@ -180,16 +177,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
180177 )
181178 self .update_grid (grid )
182179
183- base_lazy_tsr = to_linear_operator (self ._inducing_forward (last_dim_is_batch = last_dim_is_batch , ** params ))
184- if last_dim_is_batch and base_lazy_tsr .size (- 3 ) == 1 :
185- base_lazy_tsr = base_lazy_tsr .repeat (* x1 .shape [:- 2 ], x1 .size (- 1 ), 1 , 1 )
186-
187- left_interp_indices , left_interp_values = self ._compute_grid (x1 , last_dim_is_batch )
180+ base_lazy_tsr = to_linear_operator (self ._inducing_forward (** params ))
181+ left_interp_indices , left_interp_values = self ._compute_grid (x1 )
188182 if torch .equal (x1 , x2 ):
189183 right_interp_indices = left_interp_indices
190184 right_interp_values = left_interp_values
191185 else :
192- right_interp_indices , right_interp_values = self ._compute_grid (x2 , last_dim_is_batch )
186+ right_interp_indices , right_interp_values = self ._compute_grid (x2 )
193187
194188 batch_shape = torch .broadcast_shapes (
195189 base_lazy_tsr .batch_shape ,
0 commit comments