Skip to content

Commit cdc5292

Browse files
committed
add check for deepcopy in type attrs
1 parent 7c0fd4c commit cdc5292

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

gpytorch/lazy/lazy_tensor.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44
import warnings
55
from abc import ABC, abstractmethod
6+
from copy import deepcopy
67
from typing import Optional, Tuple
78

89
import torch
@@ -1067,12 +1068,18 @@ def double(self, device_id=None):
10671068
new_kwargs = {}
10681069
for arg in self._args:
10691070
if hasattr(arg, "double"):
1070-
new_args.append(arg.double())
1071+
try:
1072+
new_args.append(arg.clone().double())
1073+
except AttributeError:
1074+
new_args.append(deepcopy(arg).double())
10711075
else:
10721076
new_args.append(arg)
10731077
for name, val in self._kwargs.items():
10741078
if hasattr(val, "double"):
1075-
new_kwargs[name] = val.double()
1079+
try:
1080+
new_kwargs[name] = val.clone().double()
1081+
except AttributeError:
1082+
new_kwargs[name] = deepcopy(val).double()
10761083
else:
10771084
new_kwargs[name] = val
10781085
return self.__class__(*new_args, **new_kwargs)
@@ -1130,12 +1137,18 @@ def float(self, device_id=None):
11301137
new_kwargs = {}
11311138
for arg in self._args:
11321139
if hasattr(arg, "float"):
1133-
new_args.append(arg.float())
1140+
try:
1141+
new_args.append(arg.clone().float())
1142+
except AttributeError:
1143+
new_args.append(deepcopy(arg).float())
11341144
else:
11351145
new_args.append(arg)
11361146
for name, val in self._kwargs.items():
11371147
if hasattr(val, "float"):
1138-
new_kwargs[name] = val.float()
1148+
try:
1149+
new_kwargs[name] = val.clone().float()
1150+
except AttributeError:
1151+
new_kwargs[name] = deepcopy(val).float()
11391152
else:
11401153
new_kwargs[name] = val
11411154
return self.__class__(*new_args, **new_kwargs)
@@ -1148,12 +1161,18 @@ def half(self, device_id=None):
11481161
new_kwargs = {}
11491162
for arg in self._args:
11501163
if hasattr(arg, "half"):
1151-
new_args.append(arg.half())
1164+
try:
1165+
new_args.append(arg.clone().half())
1166+
except AttributeError:
1167+
new_args.append(deepcopy(arg).half())
11521168
else:
11531169
new_args.append(arg)
11541170
for name, val in self._kwargs.items():
11551171
if hasattr(val, "half"):
1156-
new_kwargs[name] = val.half()
1172+
try:
1173+
new_kwargs[name] = val.clone().half()
1174+
except AttributeError:
1175+
new_kwargs[name] = deepcopy(val).half()
11571176
else:
11581177
new_kwargs[name] = val
11591178
return self.__class__(*new_args, **new_kwargs)

0 commit comments

Comments
 (0)