33import math
44import warnings
55from abc import ABC , abstractmethod
6+ from copy import deepcopy
67from typing import Optional , Tuple
78
89import torch
@@ -1067,12 +1068,18 @@ def double(self, device_id=None):
10671068 new_kwargs = {}
10681069 for arg in self ._args :
10691070 if hasattr (arg , "double" ):
1070- new_args .append (arg .double ())
1071+ try :
1072+ new_args .append (arg .clone ().double ())
1073+ except AttributeError :
1074+ new_args .append (deepcopy (arg ).double ())
10711075 else :
10721076 new_args .append (arg )
10731077 for name , val in self ._kwargs .items ():
10741078 if hasattr (val , "double" ):
1075- new_kwargs [name ] = val .double ()
1079+ try :
1080+ new_kwargs [name ] = val .clone ().double ()
1081+ except AttributeError :
1082+ new_kwargs [name ] = deepcopy (val ).double ()
10761083 else :
10771084 new_kwargs [name ] = val
10781085 return self .__class__ (* new_args , ** new_kwargs )
@@ -1130,12 +1137,18 @@ def float(self, device_id=None):
11301137 new_kwargs = {}
11311138 for arg in self ._args :
11321139 if hasattr (arg , "float" ):
1133- new_args .append (arg .float ())
1140+ try :
1141+ new_args .append (arg .clone ().float ())
1142+ except AttributeError :
1143+ new_args .append (deepcopy (arg ).float ())
11341144 else :
11351145 new_args .append (arg )
11361146 for name , val in self ._kwargs .items ():
11371147 if hasattr (val , "float" ):
1138- new_kwargs [name ] = val .float ()
1148+ try :
1149+ new_kwargs [name ] = val .clone ().float ()
1150+ except AttributeError :
1151+ new_kwargs [name ] = deepcopy (val ).float ()
11391152 else :
11401153 new_kwargs [name ] = val
11411154 return self .__class__ (* new_args , ** new_kwargs )
@@ -1148,12 +1161,18 @@ def half(self, device_id=None):
11481161 new_kwargs = {}
11491162 for arg in self ._args :
11501163 if hasattr (arg , "half" ):
1151- new_args .append (arg .half ())
1164+ try :
1165+ new_args .append (arg .clone ().half ())
1166+ except AttributeError :
1167+ new_args .append (deepcopy (arg ).half ())
11521168 else :
11531169 new_args .append (arg )
11541170 for name , val in self ._kwargs .items ():
11551171 if hasattr (val , "half" ):
1156- new_kwargs [name ] = val .half ()
1172+ try :
1173+ new_kwargs [name ] = val .clone ().half ()
1174+ except AttributeError :
1175+ new_kwargs [name ] = deepcopy (val ).half ()
11571176 else :
11581177 new_kwargs [name ] = val
11591178 return self .__class__ (* new_args , ** new_kwargs )
0 commit comments