@@ -1122,6 +1122,42 @@ def evaluate_kernel(self):
11221122 """
11231123 return self .representation_tree ()(* self .representation ())
11241124
1125+ def float (self , device_id = None ):
1126+ """
1127+ This method operates identically to :func:`torch.Tensor.float`.
1128+ """
1129+ new_args = []
1130+ new_kwargs = {}
1131+ for arg in self ._args :
1132+ if hasattr (arg , "float" ):
1133+ new_args .append (arg .float ())
1134+ else :
1135+ new_args .append (arg )
1136+ for name , val in self ._kwargs .items ():
1137+ if hasattr (val , "float" ):
1138+ new_kwargs [name ] = val .float ()
1139+ else :
1140+ new_kwargs [name ] = val
1141+ return self .__class__ (* new_args , ** new_kwargs )
1142+
1143+ def half (self , device_id = None ):
1144+ """
1145+ This method operates identically to :func:`torch.Tensor.half`.
1146+ """
1147+ new_args = []
1148+ new_kwargs = {}
1149+ for arg in self ._args :
1150+ if hasattr (arg , "half" ):
1151+ new_args .append (arg .half ())
1152+ else :
1153+ new_args .append (arg )
1154+ for name , val in self ._kwargs .items ():
1155+ if hasattr (val , "half" ):
1156+ new_kwargs [name ] = val .half ()
1157+ else :
1158+ new_kwargs [name ] = val
1159+ return self .__class__ (* new_args , ** new_kwargs )
1160+
11251161 def inv_matmul (self , right_tensor , left_tensor = None ):
11261162 r"""
11271163 Computes a linear solve (w.r.t self = :math:`A`) with several right hand sides :math:`R`.
0 commit comments