|
31 | 31 | from ..utils.warnings import NumericalWarning |
32 | 32 | from .lazy_tensor_representation_tree import LazyTensorRepresentationTree |
33 | 33 |
|
| 34 | +_TYPES_DICT = {torch.float: "float", torch.half: "half", torch.double: "double"} |
| 35 | + |
34 | 36 |
|
35 | 37 | class LazyTensor(ABC): |
36 | 38 | r""" |
@@ -1064,25 +1066,7 @@ def double(self, device_id=None): |
1064 | 1066 | """ |
1065 | 1067 | This method operates identically to :func:`torch.Tensor.double`. |
1066 | 1068 | """ |
1067 | | - new_args = [] |
1068 | | - new_kwargs = {} |
1069 | | - for arg in self._args: |
1070 | | - if hasattr(arg, "double"): |
1071 | | - try: |
1072 | | - new_args.append(arg.clone().double()) |
1073 | | - except AttributeError: |
1074 | | - new_args.append(deepcopy(arg).double()) |
1075 | | - else: |
1076 | | - new_args.append(arg) |
1077 | | - for name, val in self._kwargs.items(): |
1078 | | - if hasattr(val, "double"): |
1079 | | - try: |
1080 | | - new_kwargs[name] = val.clone().double() |
1081 | | - except AttributeError: |
1082 | | - new_kwargs[name] = deepcopy(val).double() |
1083 | | - else: |
1084 | | - new_kwargs[name] = val |
1085 | | - return self.__class__(*new_args, **new_kwargs) |
| 1069 | + return self.type(torch.double) |
1086 | 1070 |
|
1087 | 1071 | @property |
1088 | 1072 | def dtype(self): |
@@ -1133,49 +1117,13 @@ def float(self, device_id=None): |
1133 | 1117 | """ |
1134 | 1118 | This method operates identically to :func:`torch.Tensor.float`. |
1135 | 1119 | """ |
1136 | | - new_args = [] |
1137 | | - new_kwargs = {} |
1138 | | - for arg in self._args: |
1139 | | - if hasattr(arg, "float"): |
1140 | | - try: |
1141 | | - new_args.append(arg.clone().float()) |
1142 | | - except AttributeError: |
1143 | | - new_args.append(deepcopy(arg).float()) |
1144 | | - else: |
1145 | | - new_args.append(arg) |
1146 | | - for name, val in self._kwargs.items(): |
1147 | | - if hasattr(val, "float"): |
1148 | | - try: |
1149 | | - new_kwargs[name] = val.clone().float() |
1150 | | - except AttributeError: |
1151 | | - new_kwargs[name] = deepcopy(val).float() |
1152 | | - else: |
1153 | | - new_kwargs[name] = val |
1154 | | - return self.__class__(*new_args, **new_kwargs) |
| 1120 | + return self.type(torch.float) |
1155 | 1121 |
|
1156 | 1122 | def half(self, device_id=None): |
1157 | 1123 | """ |
1158 | 1124 | This method operates identically to :func:`torch.Tensor.half`. |
1159 | 1125 | """ |
1160 | | - new_args = [] |
1161 | | - new_kwargs = {} |
1162 | | - for arg in self._args: |
1163 | | - if hasattr(arg, "half"): |
1164 | | - try: |
1165 | | - new_args.append(arg.clone().half()) |
1166 | | - except AttributeError: |
1167 | | - new_args.append(deepcopy(arg).half()) |
1168 | | - else: |
1169 | | - new_args.append(arg) |
1170 | | - for name, val in self._kwargs.items(): |
1171 | | - if hasattr(val, "half"): |
1172 | | - try: |
1173 | | - new_kwargs[name] = val.clone().half() |
1174 | | - except AttributeError: |
1175 | | - new_kwargs[name] = deepcopy(val).half() |
1176 | | - else: |
1177 | | - new_kwargs[name] = val |
1178 | | - return self.__class__(*new_args, **new_kwargs) |
| 1126 | + return self.type(torch.half) |
1179 | 1127 |
|
1180 | 1128 | def inv_matmul(self, right_tensor, left_tensor=None): |
1181 | 1129 | r""" |
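With this change, double(), float(), and half() become thin wrappers around a shared type() method instead of three copies of the same conversion loop. A minimal standalone sketch of the resulting delegation pattern (the _ToyLazyTensor name and its single-tensor constructor are illustrative assumptions, not the library's API):

    import torch

    class _ToyLazyTensor:
        # Illustrative stand-in for a LazyTensor subclass holding one tensor argument.
        def __init__(self, tensor):
            self._args = (tensor,)
            self._kwargs = {}

        def type(self, dtype):
            # Convert every tensor constructor argument to `dtype` and rebuild the object.
            new_args = tuple(a.to(dtype) if torch.is_tensor(a) else a for a in self._args)
            return self.__class__(*new_args, **self._kwargs)

        # The per-dtype methods now just defer to type(), mirroring the diff above.
        def double(self):
            return self.type(torch.double)

        def float(self):
            return self.type(torch.float)

        def half(self):
            return self.type(torch.half)

For example, _ToyLazyTensor(torch.zeros(3)).double()._args[0].dtype evaluates to torch.float64.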
@@ -1980,14 +1928,30 @@ def transpose(self, dim1, dim2): |
1980 | 1928 | return res |
1981 | 1929 |
|
1982 | 1930 | def type(self, dtype): |
1983 | | - if dtype == torch.float: |
1984 | | - return self.float() |
1985 | | - elif dtype == torch.double: |
1986 | | - return self.double() |
1987 | | - elif dtype == torch.half: |
1988 | | - return self.half() |
1989 | | - else: |
1990 | | - raise RuntimeError("Dtype", dtype, "not found.") |
| 1931 | + """ |
| 1932 | + This method operates similarly to :func:`torch.Tensor.type`. |
| 1933 | + """ |
| 1934 | + attr_flag = _TYPES_DICT[dtype] |
| 1935 | + |
| 1936 | + new_args = [] |
| 1937 | + new_kwargs = {} |
| 1938 | + for arg in self._args: |
| 1939 | + if hasattr(arg, attr_flag): |
| 1940 | + try: |
| 1941 | + new_args.append(arg.clone().to(dtype)) |
| 1942 | + except AttributeError: |
| 1943 | + new_args.append(deepcopy(arg).to(dtype)) |
| 1944 | + else: |
| 1945 | + new_args.append(arg) |
| 1946 | + for name, val in self._kwargs.items(): |
| 1947 | + if hasattr(val, attr_flag): |
| 1948 | + try: |
| 1949 | + new_kwargs[name] = val.clone().to(dtype) |
| 1950 | + except AttributeError: |
| 1951 | + new_kwargs[name] = deepcopy(val).to(dtype) |
| 1952 | + else: |
| 1953 | + new_kwargs[name] = val |
| 1954 | + return self.__class__(*new_args, **new_kwargs) |
1991 | 1955 |
|
1992 | 1956 | def unsqueeze(self, dim): |
1993 | 1957 | positive_dim = (self.dim() + dim + 1) if dim < 0 else dim |
|
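The try/except in the new type() method mirrors the logic previously duplicated in double(), float(), and half(): arguments exposing the dtype-named attribute are converted, tensors are cloned before casting, and objects that implement the conversion but not clone() fall back to deepcopy. A standalone sketch of that fallback, under the assumption that the argument's to() accepts a dtype (the _convert helper is illustrative, not part of the library):

    from copy import deepcopy

    import torch

    def _convert(arg, dtype, attr_flag):
        if not hasattr(arg, attr_flag):
            return arg                       # e.g. ints or strings pass through unchanged
        try:
            return arg.clone().to(dtype)     # torch.Tensor path: copy, then cast
        except AttributeError:
            return deepcopy(arg).to(dtype)   # objects without clone(): deep-copy, then cast

    print(_convert(torch.zeros(2), torch.double, "double").dtype)  # torch.float64
    print(_convert(3, torch.double, "double"))                     # 3, left as-is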