Skip to content

Commit 70af591

Browse files
committed
Remove convert_legacy_grid option.
**Note**: This is not a breaking change; "legacy" grids were deprecated pre v1.0.
1 parent d98ceab commit 70af591

File tree

6 files changed

+15
-26
lines changed

6 files changed

+15
-26
lines changed

gpytorch/kernels/grid_kernel.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch import Tensor
1010

1111
from .. import settings
12-
from ..utils.grid import convert_legacy_grid, create_data_from_grid
12+
from ..utils.grid import create_data_from_grid
1313
from .kernel import Kernel
1414

1515

@@ -58,8 +58,6 @@ def __init__(
5858
raise RuntimeError("The base_kernel for GridKernel must be stationary.")
5959

6060
super().__init__(active_dims=active_dims)
61-
if torch.is_tensor(grid):
62-
grid = convert_legacy_grid(grid)
6361
self.interpolation_mode = interpolation_mode
6462
self.base_kernel = base_kernel
6563
self.num_dims = len(grid)
@@ -84,9 +82,6 @@ def update_grid(self, grid):
8482
"""
8583
Supply a new `grid` if it ever changes.
8684
"""
87-
if torch.is_tensor(grid):
88-
grid = convert_legacy_grid(grid)
89-
9085
if len(grid) != self.num_dims:
9186
raise RuntimeError("New grid should have the same number of dimensions as before.")
9287

gpytorch/utils/grid.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,13 @@ def choose_grid_size(train_inputs, ratio=1.0, kronecker_structure=True):
9999
return ratio * num_data
100100

101101

102-
def convert_legacy_grid(grid: torch.Tensor) -> List[torch.Tensor]:
103-
return [grid[:, i] for i in range(grid.size(-1))]
104-
105-
106102
def create_data_from_grid(grid: List[torch.Tensor]) -> torch.Tensor:
107103
"""
108104
:param grid: Each Tensor is a 1D set of increments for the grid in that dimension
109105
:type grid: List[torch.Tensor]
110106
:return: The set of points on the grid going by column-major order
111107
:rtype: torch.Tensor
112108
"""
113-
if torch.is_tensor(grid):
114-
grid = convert_legacy_grid(grid)
115109
ndims = len(grid)
116110
assert all(axis.dim() == 1 for axis in grid)
117111
projections = torch.meshgrid(*grid, indexing="ij")

gpytorch/utils/interpolation.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import torch
99
from linear_operator.utils.interpolation import left_interp as _left_interp, left_t_interp as _left_t_interp
1010

11-
from .grid import convert_legacy_grid
12-
1311

1412
class Interpolation(object):
1513
def _cubic_interpolation_kernel(self, scaled_grid_dist):
@@ -41,8 +39,6 @@ def _cubic_interpolation_kernel(self, scaled_grid_dist):
4139
return res
4240

4341
def interpolate(self, x_grid: List[torch.Tensor], x_target: torch.Tensor, interp_points=range(-2, 2), eps=1e-10):
44-
if torch.is_tensor(x_grid):
45-
x_grid = convert_legacy_grid(x_grid)
4642
num_target_points = x_target.size(0)
4743
num_dim = x_target.size(-1)
4844
assert num_dim == len(x_grid)

gpytorch/variational/grid_interpolation_variational_strategy.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class GridInterpolationVariationalStrategy(_VariationalStrategy):
3030
:param list grid_bounds: Bounds of each dimension of the grid (should be a list of (float, float) tuples)
3131
:param ~gpytorch.variational.VariationalDistribution variational_distribution: A
3232
VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
33+
34+
:ivar grid: The grid of points that the inducing points are based on.
35+
The grid is stored as a matrix, where each column corresponds to the
36+
projection of the grid onto one dimension.
37+
:type grid: torch.Tensor (M x D)
3338
"""
3439

3540
def __init__(self, model, grid_size, grid_bounds, variational_distribution):
@@ -51,15 +56,14 @@ def __init__(self, model, grid_size, grid_bounds, variational_distribution):
5156
model, inducing_points, variational_distribution, learn_inducing_locations=False
5257
)
5358
object.__setattr__(self, "model", model)
54-
5559
self.register_buffer("grid", grid)
5660

5761
def _compute_grid(self, inputs):
58-
n_data, n_dimensions = inputs.size(-2), inputs.size(-1)
59-
batch_shape = inputs.shape[:-2]
62+
*batch_shape, n_data, n_dimensions = inputs.shape
63+
grid = tuple(self.grid[..., i] for i in range(n_dimensions))
6064

6165
inputs = inputs.reshape(-1, n_dimensions)
62-
interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs)
66+
interp_indices, interp_values = Interpolation().interpolate(grid, inputs)
6367
interp_indices = interp_indices.view(*batch_shape, n_data, -1)
6468
interp_values = interp_values.view(*batch_shape, n_data, -1)
6569

test/examples/test_grid_gp_regression.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ def test_grid_gp_mean_abs_error(self, num_dim=1, cuda=False):
6161
device = torch.device("cuda") if cuda else torch.device("cpu")
6262
grid_bounds = [(0, 1)] if num_dim == 1 else [(0, 1), (0, 2)]
6363
grid_size = 25
64-
grid = torch.zeros(grid_size, len(grid_bounds), device=device)
64+
grid = []
6565
for i in range(len(grid_bounds)):
6666
grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2)
67-
grid[:, i] = torch.linspace(
67+
grid.append(torch.linspace(
6868
grid_bounds[i][0] - grid_diff, grid_bounds[i][1] + grid_diff, grid_size, device=device
69-
)
69+
))
7070

7171
train_x, train_y, test_x, test_y = make_data(grid, cuda=cuda)
7272
likelihood = gpytorch.likelihoods.GaussianLikelihood()

test/utils/test_interpolation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
class TestCubicInterpolation(unittest.TestCase):
1313
def test_interpolation(self):
1414
x = torch.linspace(0.01, 1, 100).unsqueeze(1)
15-
grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)
15+
grid = [torch.linspace(-0.05, 1.05, 50)]
1616
indices, values = Interpolation().interpolate(grid, x)
1717
indices = indices.squeeze_(0)
1818
values = values.squeeze_(0)
19-
test_func_grid = grid.squeeze(1).pow(2)
19+
test_func_grid = grid[0].pow(2)
2020
test_func_x = x.pow(2).squeeze(-1)
2121

2222
interp_func_x = left_interp(indices, values, test_func_grid.unsqueeze(1)).squeeze()
@@ -25,7 +25,7 @@ def test_interpolation(self):
2525

2626
def test_multidim_interpolation(self):
2727
x = torch.tensor([[0.25, 0.45, 0.65, 0.85], [0.35, 0.375, 0.4, 0.425], [0.45, 0.5, 0.55, 0.6]]).t().contiguous()
28-
grid = torch.linspace(0.0, 1.0, 11).unsqueeze(1).repeat(1, 3)
28+
grid = [torch.linspace(0.0, 1.0, 11) for _ in range(3)]
2929

3030
indices, values = Interpolation().interpolate(grid, x)
3131

0 commit comments

Comments
 (0)