@@ -30,6 +30,11 @@ class GridInterpolationVariationalStrategy(_VariationalStrategy):
3030 :param list grid_bounds: Bounds of each dimension of the grid (should be a list of (float, float) tuples)
3131 :param ~gpytorch.variational.VariationalDistribution variational_distribution: A
3232 VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
33+
34+ :ivar grid: The grid of points that the inducing points are based on.
35+ The grid is stored as a matrix, where each column corresponds to the
36+ projection of the grid onto one dimension.
37+ :type grid: torch.Tensor (M x D)
3338 """
3439
3540 def __init__ (self , model , grid_size , grid_bounds , variational_distribution ):
@@ -51,15 +56,14 @@ def __init__(self, model, grid_size, grid_bounds, variational_distribution):
5156 model , inducing_points , variational_distribution , learn_inducing_locations = False
5257 )
5358 object .__setattr__ (self , "model" , model )
54-
5559 self .register_buffer ("grid" , grid )
5660
5761 def _compute_grid (self , inputs ):
58- n_data , n_dimensions = inputs .size ( - 2 ), inputs . size ( - 1 )
59- batch_shape = inputs . shape [: - 2 ]
62+ * batch_shape , n_data , n_dimensions = inputs .shape
63+ grid = tuple ( self . grid [..., i ] for i in range ( n_dimensions ))
6064
6165 inputs = inputs .reshape (- 1 , n_dimensions )
62- interp_indices , interp_values = Interpolation ().interpolate (self . grid , inputs )
66+ interp_indices , interp_values = Interpolation ().interpolate (grid , inputs )
6367 interp_indices = interp_indices .view (* batch_shape , n_data , - 1 )
6468 interp_values = interp_values .view (* batch_shape , n_data , - 1 )
6569
0 commit comments