Skip to content

Commit c20e2be

Browse files
committed
add float, double, half unit tests to lt testcase
1 parent cdc5292 commit c20e2be

File tree

4 files changed

+61
-6
lines changed

4 files changed

+61
-6
lines changed

gpytorch/lazy/interpolated_lazy_tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,13 @@ def diag(self):
386386
else:
387387
return super(InterpolatedLazyTensor, self).diag()
388388

389+
def double(self, device_id=None):
390+
# We need to ensure that the indices remain integers.
391+
new_lt = super().double(device_id=device_id)
392+
new_lt.left_interp_indices = new_lt.left_interp_indices.type(torch.int64)
393+
new_lt.right_interp_indices = new_lt.right_interp_indices.type(torch.int64)
394+
return new_lt
395+
389396
def matmul(self, tensor):
390397
# We're using a custom matmul here, because it is significantly faster than
391398
# what we get from the function factory.

gpytorch/test/lazy_tensor_test_case.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,14 @@ def test_cholesky(self):
475475
self.assertAllClose(res, actual, **self.tolerances["cholesky"])
476476
# TODO: Check gradients
477477

478+
def test_double(self):
479+
lazy_tensor = self.create_lazy_tensor()
480+
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
481+
482+
res = lazy_tensor.double()
483+
actual = evaluated.double()
484+
self.assertEqual(res.dtype, actual.dtype)
485+
478486
def test_diag(self):
479487
lazy_tensor = self.create_lazy_tensor()
480488
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
@@ -484,6 +492,25 @@ def test_diag(self):
484492
actual = actual.view(*lazy_tensor.batch_shape, -1)
485493
self.assertAllClose(res, actual, **self.tolerances["diag"])
486494

495+
def test_float(self):
496+
lazy_tensor = self.create_lazy_tensor().double()
497+
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
498+
499+
res = lazy_tensor.float()
500+
actual = evaluated.float()
501+
self.assertEqual(res.dtype, actual.dtype)
502+
503+
def _test_half(self, lazy_tensor):
504+
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
505+
506+
res = lazy_tensor.half()
507+
actual = evaluated.half()
508+
self.assertEqual(res.dtype, actual.dtype)
509+
510+
def test_half(self):
511+
lazy_tensor = self.create_lazy_tensor()
512+
self._test_half(lazy_tensor)
513+
487514
def test_inv_matmul_vector(self, cholesky=False):
488515
lazy_tensor = self.create_lazy_tensor()
489516
rhs = torch.randn(lazy_tensor.size(-1))

test/lazy/test_interpolated_lazy_tensor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def create_lazy_tensor(self):
3535
)
3636

3737
def evaluate_lazy_tensor(self, lazy_tensor):
38-
left_matrix = torch.zeros(4, 6)
39-
right_matrix = torch.zeros(4, 6)
38+
left_matrix = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
39+
right_matrix = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
4040
left_matrix.scatter_(1, lazy_tensor.left_interp_indices, lazy_tensor.left_interp_values)
4141
right_matrix.scatter_(1, lazy_tensor.right_interp_indices, lazy_tensor.right_interp_values)
4242

@@ -75,8 +75,8 @@ def evaluate_lazy_tensor(self, lazy_tensor):
7575
left_matrix_comps = []
7676
right_matrix_comps = []
7777
for i in range(5):
78-
left_matrix_comp = torch.zeros(4, 6)
79-
right_matrix_comp = torch.zeros(4, 6)
78+
left_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
79+
right_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
8080
left_matrix_comp.scatter_(1, lazy_tensor.left_interp_indices[i], lazy_tensor.left_interp_values[i])
8181
right_matrix_comp.scatter_(1, lazy_tensor.right_interp_indices[i], lazy_tensor.right_interp_values[i])
8282
left_matrix_comps.append(left_matrix_comp.unsqueeze(0))
@@ -121,8 +121,8 @@ def evaluate_lazy_tensor(self, lazy_tensor):
121121
right_matrix_comps = []
122122
for i in range(2):
123123
for j in range(5):
124-
left_matrix_comp = torch.zeros(4, 6)
125-
right_matrix_comp = torch.zeros(4, 6)
124+
left_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
125+
right_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
126126
left_matrix_comp.scatter_(
127127
1, lazy_tensor.left_interp_indices[i, j], lazy_tensor.left_interp_values[i, j]
128128
)

test/lazy/test_lazy_evaluated_kernel_tensor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ def test_getitem_tensor_index(self):
127127
def test_quad_form_derivative(self):
128128
pass
129129

130+
def test_half(self):
131+
# many transform operations aren't supported in half so we overwrite
132+
# this test
133+
lazy_tensor = self.create_lazy_tensor()
134+
lazy_tensor.kernel.raw_lengthscale_constraint.transform = lambda x: x + 0.1
135+
self._test_half(lazy_tensor)
136+
130137

131138
class TestLazyEvaluatedKernelTensorMultitaskBatch(TestLazyEvaluatedKernelTensorBatch):
132139
seed = 0
@@ -140,6 +147,13 @@ def create_lazy_tensor(self):
140147
def test_inv_matmul_matrix_with_checkpointing(self):
141148
pass
142149

150+
def test_half(self):
151+
# many transform operations aren't supported in half so we overwrite
152+
# this test
153+
lazy_tensor = self.create_lazy_tensor()
154+
lazy_tensor.kernel.data_covar_module.raw_lengthscale_constraint.transform = lambda x: x + 0.1
155+
self._test_half(lazy_tensor)
156+
143157

144158
class TestLazyEvaluatedKernelTensorAdditive(TestLazyEvaluatedKernelTensorBatch):
145159
seed = 0
@@ -162,3 +176,10 @@ def evaluate_lazy_tensor(self, lazy_tensor):
162176

163177
def test_inv_matmul_matrix_with_checkpointing(self):
164178
pass
179+
180+
def test_half(self):
181+
# many transform operations aren't supported in half so we overwrite
182+
# this test
183+
lazy_tensor = self.create_lazy_tensor()
184+
lazy_tensor.kernel.base_kernel.raw_lengthscale_constraint.transform = lambda x: x + 0.1
185+
self._test_half(lazy_tensor)

0 commit comments

Comments
 (0)