Skip to content

Commit 7986d43

Browse files
Merge branch 'master' into master
2 parents aee6106 + 7648de1 commit 7986d43

File tree

5 files changed

+103
-19
lines changed

5 files changed

+103
-19
lines changed

gpytorch/lazy/interpolated_lazy_tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,13 @@ def diag(self):
386386
else:
387387
return super(InterpolatedLazyTensor, self).diag()
388388

389+
def double(self, device_id=None):
390+
# We need to ensure that the indices remain integers.
391+
new_lt = super().double(device_id=device_id)
392+
new_lt.left_interp_indices = new_lt.left_interp_indices.type(torch.int64)
393+
new_lt.right_interp_indices = new_lt.right_interp_indices.type(torch.int64)
394+
return new_lt
395+
389396
def matmul(self, tensor):
390397
# We're using a custom matmul here, because it is significantly faster than
391398
# what we get from the function factory.

gpytorch/lazy/lazy_tensor.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44
import warnings
55
from abc import ABC, abstractmethod
6+
from copy import deepcopy
67
from typing import Optional, Tuple
78

89
import torch
@@ -30,6 +31,8 @@
3031
from ..utils.warnings import NumericalWarning
3132
from .lazy_tensor_representation_tree import LazyTensorRepresentationTree
3233

34+
_TYPES_DICT = {torch.float: "float", torch.half: "half", torch.double: "double"}
35+
3336

3437
class LazyTensor(ABC):
3538
r"""
@@ -1063,19 +1066,7 @@ def double(self, device_id=None):
10631066
"""
10641067
This method operates identically to :func:`torch.Tensor.double`.
10651068
"""
1066-
new_args = []
1067-
new_kwargs = {}
1068-
for arg in self._args:
1069-
if hasattr(arg, "double"):
1070-
new_args.append(arg.double())
1071-
else:
1072-
new_args.append(arg)
1073-
for name, val in self._kwargs.items():
1074-
if hasattr(val, "double"):
1075-
new_kwargs[name] = val.double()
1076-
else:
1077-
new_kwargs[name] = val
1078-
return self.__class__(*new_args, **new_kwargs)
1069+
return self.type(torch.double)
10791070

10801071
@property
10811072
def dtype(self):
@@ -1122,6 +1113,18 @@ def evaluate_kernel(self):
11221113
"""
11231114
return self.representation_tree()(*self.representation())
11241115

1116+
def float(self, device_id=None):
1117+
"""
1118+
This method operates identically to :func:`torch.Tensor.float`.
1119+
"""
1120+
return self.type(torch.float)
1121+
1122+
def half(self, device_id=None):
1123+
"""
1124+
This method operates identically to :func:`torch.Tensor.half`.
1125+
"""
1126+
return self.type(torch.half)
1127+
11251128
def inv_matmul(self, right_tensor, left_tensor=None):
11261129
r"""
11271130
Computes a linear solve (w.r.t self = :math:`A`) with several right hand sides :math:`R`.
@@ -1924,6 +1927,32 @@ def transpose(self, dim1, dim2):
19241927

19251928
return res
19261929

1930+
def type(self, dtype):
1931+
"""
1932+
This method operates similarly to :func:`torch.Tensor.type`.
1933+
"""
1934+
attr_flag = _TYPES_DICT[dtype]
1935+
1936+
new_args = []
1937+
new_kwargs = {}
1938+
for arg in self._args:
1939+
if hasattr(arg, attr_flag):
1940+
try:
1941+
new_args.append(arg.clone().to(dtype))
1942+
except AttributeError:
1943+
new_args.append(deepcopy(arg).to(dtype))
1944+
else:
1945+
new_args.append(arg)
1946+
for name, val in self._kwargs.items():
1947+
if hasattr(val, attr_flag):
1948+
try:
1949+
new_kwargs[name] = val.clone().to(dtype)
1950+
except AttributeError:
1951+
new_kwargs[name] = deepcopy(val).to(dtype)
1952+
else:
1953+
new_kwargs[name] = val
1954+
return self.__class__(*new_args, **new_kwargs)
1955+
19271956
def unsqueeze(self, dim):
19281957
positive_dim = (self.dim() + dim + 1) if dim < 0 else dim
19291958
if positive_dim > len(self.batch_shape):

gpytorch/test/lazy_tensor_test_case.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,14 @@ def test_cholesky(self):
475475
self.assertAllClose(res, actual, **self.tolerances["cholesky"])
476476
# TODO: Check gradients
477477

478+
def test_double(self):
479+
lazy_tensor = self.create_lazy_tensor()
480+
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
481+
482+
res = lazy_tensor.double()
483+
actual = evaluated.double()
484+
self.assertEqual(res.dtype, actual.dtype)
485+
478486
def test_diag(self):
479487
lazy_tensor = self.create_lazy_tensor()
480488
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
@@ -484,6 +492,25 @@ def test_diag(self):
484492
actual = actual.view(*lazy_tensor.batch_shape, -1)
485493
self.assertAllClose(res, actual, **self.tolerances["diag"])
486494

495+
def test_float(self):
496+
lazy_tensor = self.create_lazy_tensor().double()
497+
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
498+
499+
res = lazy_tensor.float()
500+
actual = evaluated.float()
501+
self.assertEqual(res.dtype, actual.dtype)
502+
503+
def _test_half(self, lazy_tensor):
504+
evaluated = self.evaluate_lazy_tensor(lazy_tensor)
505+
506+
res = lazy_tensor.half()
507+
actual = evaluated.half()
508+
self.assertEqual(res.dtype, actual.dtype)
509+
510+
def test_half(self):
511+
lazy_tensor = self.create_lazy_tensor()
512+
self._test_half(lazy_tensor)
513+
487514
def test_inv_matmul_vector(self, cholesky=False):
488515
lazy_tensor = self.create_lazy_tensor()
489516
rhs = torch.randn(lazy_tensor.size(-1))

test/lazy/test_interpolated_lazy_tensor.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def create_lazy_tensor(self):
3535
)
3636

3737
def evaluate_lazy_tensor(self, lazy_tensor):
38-
left_matrix = torch.zeros(4, 6)
39-
right_matrix = torch.zeros(4, 6)
38+
left_matrix = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
39+
right_matrix = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
4040
left_matrix.scatter_(1, lazy_tensor.left_interp_indices, lazy_tensor.left_interp_values)
4141
right_matrix.scatter_(1, lazy_tensor.right_interp_indices, lazy_tensor.right_interp_values)
4242

@@ -75,8 +75,8 @@ def evaluate_lazy_tensor(self, lazy_tensor):
7575
left_matrix_comps = []
7676
right_matrix_comps = []
7777
for i in range(5):
78-
left_matrix_comp = torch.zeros(4, 6)
79-
right_matrix_comp = torch.zeros(4, 6)
78+
left_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
79+
right_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
8080
left_matrix_comp.scatter_(1, lazy_tensor.left_interp_indices[i], lazy_tensor.left_interp_values[i])
8181
right_matrix_comp.scatter_(1, lazy_tensor.right_interp_indices[i], lazy_tensor.right_interp_values[i])
8282
left_matrix_comps.append(left_matrix_comp.unsqueeze(0))
@@ -121,8 +121,8 @@ def evaluate_lazy_tensor(self, lazy_tensor):
121121
right_matrix_comps = []
122122
for i in range(2):
123123
for j in range(5):
124-
left_matrix_comp = torch.zeros(4, 6)
125-
right_matrix_comp = torch.zeros(4, 6)
124+
left_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
125+
right_matrix_comp = torch.zeros(4, 6, dtype=lazy_tensor.dtype)
126126
left_matrix_comp.scatter_(
127127
1, lazy_tensor.left_interp_indices[i, j], lazy_tensor.left_interp_values[i, j]
128128
)

test/lazy/test_lazy_evaluated_kernel_tensor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,13 @@ def test_getitem_tensor_index(self):
127127
def test_quad_form_derivative(self):
128128
pass
129129

130+
def test_half(self):
131+
# many transform operations aren't supported in half so we overwrite
132+
# this test
133+
lazy_tensor = self.create_lazy_tensor()
134+
lazy_tensor.kernel.raw_lengthscale_constraint.transform = lambda x: x + 0.1
135+
self._test_half(lazy_tensor)
136+
130137

131138
class TestLazyEvaluatedKernelTensorMultitaskBatch(TestLazyEvaluatedKernelTensorBatch):
132139
seed = 0
@@ -140,6 +147,13 @@ def create_lazy_tensor(self):
140147
def test_inv_matmul_matrix_with_checkpointing(self):
141148
pass
142149

150+
def test_half(self):
151+
# many transform operations aren't supported in half so we overwrite
152+
# this test
153+
lazy_tensor = self.create_lazy_tensor()
154+
lazy_tensor.kernel.data_covar_module.raw_lengthscale_constraint.transform = lambda x: x + 0.1
155+
self._test_half(lazy_tensor)
156+
143157

144158
class TestLazyEvaluatedKernelTensorAdditive(TestLazyEvaluatedKernelTensorBatch):
145159
seed = 0
@@ -162,3 +176,10 @@ def evaluate_lazy_tensor(self, lazy_tensor):
162176

163177
def test_inv_matmul_matrix_with_checkpointing(self):
164178
pass
179+
180+
def test_half(self):
181+
# many transform operations aren't supported in half so we overwrite
182+
# this test
183+
lazy_tensor = self.create_lazy_tensor()
184+
lazy_tensor.kernel.base_kernel.raw_lengthscale_constraint.transform = lambda x: x + 0.1
185+
self._test_half(lazy_tensor)

0 commit comments

Comments
 (0)