| 
19 | 19 | from ..functions._matmul import Matmul  | 
20 | 20 | from ..functions._root_decomposition import RootDecomposition  | 
21 | 21 | from ..functions._sqrt_inv_matmul import SqrtInvMatmul  | 
22 |  | -from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape  | 
 | 22 | +from ..utils.broadcasting import _matmul_broadcast_shape, _mul_broadcast_shape, _to_helper  | 
23 | 23 | from ..utils.cholesky import psd_safe_cholesky  | 
24 | 24 | from ..utils.deprecation import _deprecate_renamed_methods  | 
25 | 25 | from ..utils.errors import CachingError  | 
@@ -1859,25 +1859,31 @@ def symeig(self, eigenvectors: bool = False) -> Tuple[Tensor, Optional["LazyTens  | 
1859 | 1859 |             pass  | 
1860 | 1860 |         return self._symeig(eigenvectors=eigenvectors)  | 
1861 | 1861 | 
 
  | 
1862 |  | -    def to(self, device_id):  | 
 | 1862 | +    def to(self, *args, **kwargs):  | 
1863 | 1863 |         """  | 
1864 |  | -        A device-agnostic method of moving the lazy_tensor to the specified device.  | 
 | 1864 | +        A device-agnostic method of moving the lazy_tensor to the specified device or dtype.  | 
 | 1865 | +        Note that we do NOT support non_blocking or other `torch.to` options other than  | 
 | 1866 | +        device and dtype and these options will be silently ignored.  | 
1865 | 1867 | 
  | 
1866 | 1868 |         Args:  | 
1867 |  | -            device_id (:obj: `torch.device`): Which device to use (GPU or CPU).  | 
 | 1869 | +            device (:obj: `torch.device`): Which device to use (GPU or CPU).  | 
 | 1870 | +            dtype (:obj: `torch.dtype`): Which dtype to use (double, float, or half).  | 
1868 | 1871 |         Returns:  | 
1869 | 1872 |             :obj:`~gpytorch.lazy.LazyTensor`: New LazyTensor identical to self on specified device  | 
1870 | 1873 |         """  | 
 | 1874 | + | 
 | 1875 | +        device, dtype = _to_helper(*args, **kwargs)  | 
 | 1876 | + | 
1871 | 1877 |         new_args = []  | 
1872 | 1878 |         new_kwargs = {}  | 
1873 | 1879 |         for arg in self._args:  | 
1874 | 1880 |             if hasattr(arg, "to"):  | 
1875 |  | -                new_args.append(arg.to(device_id))  | 
 | 1881 | +                new_args.append(arg.to(dtype=dtype, device=device))  | 
1876 | 1882 |             else:  | 
1877 | 1883 |                 new_args.append(arg)  | 
1878 | 1884 |         for name, val in self._kwargs.items():  | 
1879 | 1885 |             if hasattr(val, "to"):  | 
1880 |  | -                new_kwargs[name] = val.to(device_id)  | 
 | 1886 | +                new_kwargs[name] = val.to(dtype=dtype, device=device)  | 
1881 | 1887 |             else:  | 
1882 | 1888 |                 new_kwargs[name] = val  | 
1883 | 1889 |         return self.__class__(*new_args, **new_kwargs)  | 
 | 
0 commit comments