 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
-from typing import Any, Dict, Optional, Tuple, Callable, Union
-import torch
-from torch._C import _disabled_torch_function_impl
-import torch.utils._pytree as pytree
-from torch.fx import Tracer, GraphModule
-import torch.fx as fx
-from torch.fx.passes.shape_prop import _extract_tensor_metadata
-from contextlib import contextmanager
+__all__ = ["make_fx", "ProxyTensor", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
+from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose
 
-aten = torch.ops.aten
-
-CURRENT_DECOMPOSITION_TABLE = {}
-
-
-@contextmanager
-def no_dispatch():
-    guard = torch._C._DisableTorchDispatch()
-    try:
-        yield
-    finally:
-        del guard
-
-
-@contextmanager
-def pythonkey_decompose(decomposition_table):
-    global CURRENT_DECOMPOSITION_TABLE
-    CURRENT_DECOMPOSITION_TABLE = decomposition_table
-    try:
-        yield CURRENT_DECOMPOSITION_TABLE
-    finally:
-        CURRENT_DECOMPOSITION_TABLE = {}
-
-
-class PythonTensor(torch.Tensor):
-    elem: torch.Tensor
-
-    __slots__ = ['elem', 'proxy']
-
-    @staticmethod
-    def __new__(cls, elem, proxy):
-        # Wrapping something in PythonTensor implicitly detaches
-        # gradients. If something required grad, we will collect it as if it
-        # were a leaf. A consequence of detaching in this way is you
-        # need to maintain a parameter cache when translating tensors
-        # into PythonTensor, so you don't create multiple copies of
-        # a gradient (they are aliased, but they would count as independent
-        # leaves). An alternate strategy would be to avoid implicitly
-        # detaching and instead "catch" gradients as they exit the
-        # PythonTensor boundary.
-        # assert not elem.requires_grad or not torch.is_grad_enabled()
-
-        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-        r.proxy = proxy
-        if elem.is_sparse:
-            proxy.node.meta['tensor_meta'] = {}
-        else:
-            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
-        return r
-
-    def __repr__(self):
-        with no_dispatch():
-            return f"PythonTensor({self.as_subclass(torch.Tensor)})"
-
-    __torch_function__ = _disabled_torch_function_impl
-
-    def __deepcopy__(self, memo):
-        return self.clone()
-
-    @classmethod
-    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
-        func = func_overload.overloadpacket
-        if func_overload in CURRENT_DECOMPOSITION_TABLE:
-            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
-        # Commenting this out for now since it causes some spurious failures (such as error checking)
-        # if func == aten._local_scalar_dense:
-        #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
-        #                        "It's likely that this is caused by data-dependent control flow or similar.")
-
-        def unwrap_proxy(e):
-            return e.proxy if isinstance(e, PythonTensor) else e
-
-        proxy_args = pytree.tree_map(unwrap_proxy, args)
-        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
-
-        proxy_out = func(*proxy_args, **proxy_kwargs)
-
-        # Kind of a hacky way to test if an op is in-place or not
-        if func.__name__[-1] == "_" and func.__name__[0] != "_":
-            args[0].proxy = proxy_out
-            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
-
-        with no_dispatch():
-            real_out = func_overload(*args, **kwargs)
-
-        def wrap_with_proxy(e, proxy):
-            # Some ops (like native_batch_norm_backward) return undefined tensors that get
-            # converted into None in python.
-            # As the function signature expects tensors, if we directly return these None
-            # tensors back to C++, we'll error.
-            if e is None:
-                e = torch.empty(())
-            if type(e) == torch.Tensor:
-                return PythonTensor(e, proxy)
-            else:
-                return e
-        if isinstance(real_out, tuple):
-            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
-        elif isinstance(real_out, list):
-            return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
-        elif isinstance(real_out, torch.Tensor):
-            return wrap_with_proxy(real_out, proxy_out)
-        else:
-            return real_out
-
-
-class PythonKeyTracer(Tracer):
-    def __init__(self):
-        super().__init__()
-
-    def call_module(
-            self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
-    ) -> Any:
-        return forward(*args, **kwargs)
-
-    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
-        if isinstance(attr_val, torch.nn.Parameter):
-            for n, p in self.root.named_parameters():
-                if attr_val is p:
-                    if n not in parameter_proxy_cache:
-                        proxy = self.create_proxy('get_attr', n, (), {})
-                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
-                    return parameter_proxy_cache[n]
-            return attr_val
-        return attr_val
-
-    # We need to do this so that parameters entering the `make_fx` context have
-    # a reference to them (and also have requires_grad set on them correctly
-    # I'm not actually sure if this is the right thing to do ...
-    def create_arg(self, a: Any):
-        if isinstance(a, torch.nn.Parameter):
-            for n, p in self.root.named_parameters():
-                if a is p:
-                    return self.create_node('get_attr', n, (), {})
-            qualname: Optional[str] = None
-
-            if not qualname:
-                i = 0
-                while True:
-                    qualname = f'_param_constant{i}'
-                    if not hasattr(self.root, qualname):
-                        break
-                    i += 1
-                setattr(self.root, qualname, a)
-
-            return self.create_node('get_attr', qualname, (), {})
-        return super().create_arg(a)
-
-
-def pythonkey_trace(
-    root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None
-) -> GraphModule:
-    tracer = PythonKeyTracer()
-    graph = tracer.trace(root, concrete_args)
-    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
-    return GraphModule(tracer.root, graph, name)
-
-
-def wrap_key(f, inps):
-    flat_inps, inp_spec = pytree.tree_flatten(inps)
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        flat_args, args_spec = pytree.tree_flatten(args)
-        assert(len(flat_args) == len(flat_inps))
-        for idx, arg in enumerate(flat_args):
-            if isinstance(flat_inps[idx], torch.Tensor):
-                flat_args[idx] = PythonTensor(flat_inps[idx], arg)
-            else:
-                flat_args[idx] = flat_inps[idx]
-
-        tree_args = pytree.tree_unflatten(flat_args, args_spec)
-        out = f(*tree_args)
-        flat_outs, out_spec = pytree.tree_flatten(out)
-        for idx in range(len(flat_outs)):
-            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], PythonTensor):
-                flat_outs[idx] = flat_outs[idx].proxy
-        return pytree.tree_unflatten(flat_outs, out_spec)
-
-    return wrapped
-
-
-def make_fx(f, decomposition_table=None):
-    if decomposition_table is None:
-        decomposition_table = {}
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        phs = pytree.tree_map(lambda x: fx.PH, args)
-        with pythonkey_decompose(decomposition_table):
-            t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs))
-        return t
-
-    return wrapped
+pythonkey_decompose = decompose
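
For reference, a minimal usage sketch of the re-exported entry point (assumes a PyTorch build that ships torch.fx.experimental.proxy_tensor, as imported in the diff above; the function f below is a made-up example, not part of this commit):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # hypothetical example function; any traceable tensor computation works here
    return torch.sin(x) * torch.cos(x)

# make_fx(f) returns a wrapper; calling it with example inputs yields a
# torch.fx.GraphModule recording the aten ops, mirroring the removed local
# make_fx. An optional decomposition_table dict may be passed as the second
# argument, just as in the old signature.
gm = make_fx(f)(torch.randn(3))
print(gm.graph)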