Commit 92cf2d1

use type method

1 parent c20e2be commit 92cf2d1

1 file changed: 29 additions, 65 deletions

gpytorch/lazy/lazy_tensor.py
@@ -31,6 +31,8 @@
 from ..utils.warnings import NumericalWarning
 from .lazy_tensor_representation_tree import LazyTensorRepresentationTree
 
+_TYPES_DICT = {torch.float: "float", torch.half: "half", torch.double: "double"}
+
 
 class LazyTensor(ABC):
     r"""
@@ -1064,25 +1066,7 @@ def double(self, device_id=None):
         """
         This method operates identically to :func:`torch.Tensor.double`.
         """
-        new_args = []
-        new_kwargs = {}
-        for arg in self._args:
-            if hasattr(arg, "double"):
-                try:
-                    new_args.append(arg.clone().double())
-                except AttributeError:
-                    new_args.append(deepcopy(arg).double())
-            else:
-                new_args.append(arg)
-        for name, val in self._kwargs.items():
-            if hasattr(val, "double"):
-                try:
-                    new_kwargs[name] = val.clone().double()
-                except AttributeError:
-                    new_kwargs[name] = deepcopy(val).double()
-            else:
-                new_kwargs[name] = val
-        return self.__class__(*new_args, **new_kwargs)
+        return self.type(torch.double)
 
     @property
     def dtype(self):
@@ -1133,49 +1117,13 @@ def float(self, device_id=None):
         """
         This method operates identically to :func:`torch.Tensor.float`.
         """
-        new_args = []
-        new_kwargs = {}
-        for arg in self._args:
-            if hasattr(arg, "float"):
-                try:
-                    new_args.append(arg.clone().float())
-                except AttributeError:
-                    new_args.append(deepcopy(arg).float())
-            else:
-                new_args.append(arg)
-        for name, val in self._kwargs.items():
-            if hasattr(val, "float"):
-                try:
-                    new_kwargs[name] = val.clone().float()
-                except AttributeError:
-                    new_kwargs[name] = deepcopy(val).float()
-            else:
-                new_kwargs[name] = val
-        return self.__class__(*new_args, **new_kwargs)
+        return self.type(torch.float)
 
     def half(self, device_id=None):
         """
         This method operates identically to :func:`torch.Tensor.half`.
         """
-        new_args = []
-        new_kwargs = {}
-        for arg in self._args:
-            if hasattr(arg, "half"):
-                try:
-                    new_args.append(arg.clone().half())
-                except AttributeError:
-                    new_args.append(deepcopy(arg).half())
-            else:
-                new_args.append(arg)
-        for name, val in self._kwargs.items():
-            if hasattr(val, "half"):
-                try:
-                    new_kwargs[name] = val.clone().half()
-                except AttributeError:
-                    new_kwargs[name] = deepcopy(val).half()
-            else:
-                new_kwargs[name] = val
-        return self.__class__(*new_args, **new_kwargs)
+        return self.type(torch.half)
 
     def inv_matmul(self, right_tensor, left_tensor=None):
         r"""
@@ -1980,14 +1928,30 @@ def transpose(self, dim1, dim2):
         return res
 
     def type(self, dtype):
-        if dtype == torch.float:
-            return self.float()
-        elif dtype == torch.double:
-            return self.double()
-        elif dtype == torch.half:
-            return self.half()
-        else:
-            raise RuntimeError("Dtype", dtype, "not found.")
+        """
+        This method operates similarly to :func:`torch.Tensor.type`.
+        """
+        attr_flag = _TYPES_DICT[dtype]
+
+        new_args = []
+        new_kwargs = {}
+        for arg in self._args:
+            if hasattr(arg, attr_flag):
+                try:
+                    new_args.append(arg.clone().to(dtype))
+                except AttributeError:
+                    new_args.append(deepcopy(arg).to(dtype))
+            else:
+                new_args.append(arg)
+        for name, val in self._kwargs.items():
+            if hasattr(val, attr_flag):
+                try:
+                    new_kwargs[name] = val.clone().to(dtype)
+                except AttributeError:
+                    new_kwargs[name] = deepcopy(val).to(dtype)
+            else:
+                new_kwargs[name] = val
+        return self.__class__(*new_args, **new_kwargs)
 
     def unsqueeze(self, dim):
         positive_dim = (self.dim() + dim + 1) if dim < 0 else dim
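
After this change, double()/float()/half() all funnel through the single conversion loop in type(), so the three per-dtype copies of that loop are gone. A hedged usage sketch of the refactored API (using NonLazyTensor as the concrete subclass is my assumption; any LazyTensor subclass wrapping convertible args should behave the same way):

    import torch
    from gpytorch.lazy import NonLazyTensor

    lt = NonLazyTensor(torch.randn(3, 3))

    # The dtype-specific methods are now thin wrappers around type()
    assert lt.double().dtype == torch.double
    assert lt.type(torch.half).dtype == torch.half

One consequence of the hasattr/_TYPES_DICT design: requesting a dtype outside the dict (e.g. torch.int) now raises a KeyError rather than the old RuntimeError.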
