@@ -3,6 +3,7 @@
 import math
 import warnings
 from abc import ABC, abstractmethod
+from copy import deepcopy
 from typing import Optional, Tuple
 
 import torch
@@ -30,6 +31,8 @@
 from ..utils.warnings import NumericalWarning
 from .lazy_tensor_representation_tree import LazyTensorRepresentationTree
 
+_TYPES_DICT = {torch.float: "float", torch.half: "half", torch.double: "double"}
+
 
 class LazyTensor(ABC):
     r"""
@@ -1063,19 +1066,7 @@ def double(self, device_id=None):
         """
         This method operates identically to :func:`torch.Tensor.double`.
         """
-        new_args = []
-        new_kwargs = {}
-        for arg in self._args:
-            if hasattr(arg, "double"):
-                new_args.append(arg.double())
-            else:
-                new_args.append(arg)
-        for name, val in self._kwargs.items():
-            if hasattr(val, "double"):
-                new_kwargs[name] = val.double()
-            else:
-                new_kwargs[name] = val
-        return self.__class__(*new_args, **new_kwargs)
+        return self.type(torch.double)
 
     @property
     def dtype(self):
@@ -1122,6 +1113,18 @@ def evaluate_kernel(self):
         """
         return self.representation_tree()(*self.representation())
 
+    def float(self, device_id=None):
+        """
+        This method operates identically to :func:`torch.Tensor.float`.
+        """
+        return self.type(torch.float)
+
+    def half(self, device_id=None):
+        """
+        This method operates identically to :func:`torch.Tensor.half`.
+        """
+        return self.type(torch.half)
+
     def inv_matmul(self, right_tensor, left_tensor=None):
         r"""
         Computes a linear solve (w.r.t self = :math:`A`) with several right hand sides :math:`R`.
@@ -1924,6 +1927,32 @@ def transpose(self, dim1, dim2):
 
         return res
 
+    def type(self, dtype):
+        """
+        This method operates similarly to :func:`torch.Tensor.type`.
+        """
+        attr_flag = _TYPES_DICT[dtype]
+
+        new_args = []
+        new_kwargs = {}
+        for arg in self._args:
+            if hasattr(arg, attr_flag):
+                try:
+                    new_args.append(arg.clone().to(dtype))
+                except AttributeError:
+                    new_args.append(deepcopy(arg).to(dtype))
+            else:
+                new_args.append(arg)
+        for name, val in self._kwargs.items():
+            if hasattr(val, attr_flag):
+                try:
+                    new_kwargs[name] = val.clone().to(dtype)
+                except AttributeError:
+                    new_kwargs[name] = deepcopy(val).to(dtype)
+            else:
+                new_kwargs[name] = val
+        return self.__class__(*new_args, **new_kwargs)
+
     def unsqueeze(self, dim):
        positive_dim = (self.dim() + dim + 1) if dim < 0 else dim
        if positive_dim > len(self.batch_shape):
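For reference, a minimal sketch of how the conversion helpers added here could be exercised once the diff is applied. It uses `NonLazyTensor` from `gpytorch.lazy` purely as a convenient concrete `LazyTensor` subclass; the snippet is illustrative and not part of the change itself.

```python
import torch
from gpytorch.lazy import NonLazyTensor  # concrete LazyTensor subclass, used only for illustration

# Wrap an ordinary float32 tensor in a LazyTensor.
base = torch.randn(4, 4)
lazy = NonLazyTensor(base @ base.transpose(-1, -2))

# The new helpers mirror torch.Tensor's dtype conversions.
lazy_double = lazy.double()       # same as lazy.type(torch.double)
lazy_half = lazy.half()           # same as lazy.type(torch.half)
lazy_float = lazy_double.float()  # back to single precision

print(lazy_double.dtype)  # torch.float64
print(lazy_half.dtype)    # torch.float16

# type() rebuilds the LazyTensor from converted copies of its args,
# so the original wrapper keeps its dtype.
print(lazy.dtype)  # torch.float32
```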