Skip to content

Commit 05ebbf2

Browse files
authored
Refactor to Arguments to more closely match torch.tensor (#1746)
* Refactor to arguments
* Switch to an (args, kwargs) constructor
* Add a `_to_helper` utility to the broadcasting module
1 parent 8e03c1d commit 05ebbf2

File tree

3 files changed

+50
-13
lines changed

3 files changed

+50
-13
lines changed

gpytorch/lazy/cat_lazy_tensor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44

55
from .. import settings
6-
from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape
6+
from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper
77
from ..utils.deprecation import bool_compat
88
from ..utils.getitem import _noop_index
99
from .lazy_tensor import LazyTensor, delazify
@@ -364,15 +364,22 @@ def devices(self):
364364
def device_count(self):
365365
return len(set(self.devices))
366366

367-
def to(self, *args, **kwargs):
    """
    Returns a new CatLazyTensor with device as the output_device and dtype
    as the dtype.

    Warning: this does not move the LazyTensors in this CatLazyTensor to
    device.
    """
    # Mirror torch.Tensor.to: accept dtype/device positionally or by keyword.
    device, dtype = _to_helper(*args, **kwargs)

    # Rebuild with the same constituent tensors, only swapping the output device.
    kwargs_with_device = dict(self._kwargs)
    kwargs_with_device["output_device"] = device
    result = self.__class__(*self._args, **kwargs_with_device)

    # Only cast when a dtype was actually requested.
    return result if dtype is None else result.type(dtype)
376383

377384
def all_to(self, device_id):
378385
"""

gpytorch/lazy/lazy_tensor.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ..functions._matmul import Matmul
2020
from ..functions._root_decomposition import RootDecomposition
2121
from ..functions._sqrt_inv_matmul import SqrtInvMatmul
22-
from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape
22+
from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper
2323
from ..utils.cholesky import psd_safe_cholesky
2424
from ..utils.deprecation import _deprecate_renamed_methods
2525
from ..utils.errors import CachingError
@@ -1859,25 +1859,31 @@ def symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTens
18591859
pass
18601860
return self._symeig(eigenvectors=eigenvectors)
18611861

1862-
def to(self, *args, **kwargs):
    """
    A device-agnostic method of moving the lazy_tensor to the specified device or dtype.
    Note that we do NOT support non_blocking or other `torch.to` options other than
    device and dtype and these options will be silently ignored.

    Args:
        device (:obj: `torch.device`): Which device to use (GPU or CPU).
        dtype (:obj: `torch.dtype`): Which dtype to use (double, float, or half).
    Returns:
        :obj:`~gpytorch.lazy.LazyTensor`: New LazyTensor identical to self on specified device
    """
    device, dtype = _to_helper(*args, **kwargs)

    # Move anything that supports .to(); pass everything else through untouched.
    moved_args = [
        arg.to(dtype=dtype, device=device) if hasattr(arg, "to") else arg
        for arg in self._args
    ]
    moved_kwargs = {
        name: val.to(dtype=dtype, device=device) if hasattr(val, "to") else val
        for name, val in self._kwargs.items()
    }
    return self.__class__(*moved_args, **moved_kwargs)

gpytorch/utils/broadcasting.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,27 @@ def _pad_with_singletons(obj, num_singletons_before=0, num_singletons_after=0):
6969
"""
7070
new_shape = [1] * num_singletons_before + list(obj.shape) + [1] * num_singletons_after
7171
return obj.view(*new_shape)
72+
73+
74+
def _to_helper(*args, **kwargs):
75+
"""
76+
Silently plucks out dtype and devices from a list.
77+
78+
Example:
79+
>>> dtype, device = _to_helper(dtype=torch.float, device=torch.device("cpu"))
80+
>>> dtype, device = _to_helper(torch.float, torch.device("cpu"))
81+
"""
82+
dtype = kwargs.pop("dtype", None)
83+
device = kwargs.pop("device", None)
84+
85+
if dtype is None:
86+
dtype_list = [x for x in args if type(x) is torch.dtype]
87+
if len(dtype_list) > 0:
88+
dtype = dtype_list[0]
89+
90+
if device is None:
91+
device_list = [x for x in args if type(x) is torch.device]
92+
if len(device_list) > 0:
93+
device = device_list[0]
94+
95+
return device, dtype

0 commit comments

Comments
 (0)