Skip to content

Commit 0a75a81

Browse files
committed
add half and float calls to lazy tensor
1 parent 98dd616 commit 0a75a81

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

gpytorch/lazy/lazy_tensor.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,42 @@ def evaluate_kernel(self):
11221122
"""
11231123
return self.representation_tree()(*self.representation())
11241124

1125+
def float(self, device_id=None):
1126+
"""
1127+
This method operates identically to :func:`torch.Tensor.float`.
1128+
"""
1129+
new_args = []
1130+
new_kwargs = {}
1131+
for arg in self._args:
1132+
if hasattr(arg, "float"):
1133+
new_args.append(arg.float())
1134+
else:
1135+
new_args.append(arg)
1136+
for name, val in self._kwargs.items():
1137+
if hasattr(val, "float"):
1138+
new_kwargs[name] = val.float()
1139+
else:
1140+
new_kwargs[name] = val
1141+
return self.__class__(*new_args, **new_kwargs)
1142+
1143+
def half(self, device_id=None):
1144+
"""
1145+
This method operates identically to :func:`torch.Tensor.half`.
1146+
"""
1147+
new_args = []
1148+
new_kwargs = {}
1149+
for arg in self._args:
1150+
if hasattr(arg, "half"):
1151+
new_args.append(arg.half())
1152+
else:
1153+
new_args.append(arg)
1154+
for name, val in self._kwargs.items():
1155+
if hasattr(val, "half"):
1156+
new_kwargs[name] = val.half()
1157+
else:
1158+
new_kwargs[name] = val
1159+
return self.__class__(*new_args, **new_kwargs)
1160+
11251161
def inv_matmul(self, right_tensor, left_tensor=None):
11261162
r"""
11271163
Computes a linear solve (w.r.t self = :math:`A`) with several right hand sides :math:`R`.

0 commit comments

Comments
 (0)