
Commit aca9fbe

unify PythonKey (i.e. ProxyTensor) tracer with the one in core (#853)
* unify tracer with the one in core
* modified test
* fix lint issues
* fixed some things
1 parent 96dead4 commit aca9fbe

File tree

3 files changed: +5 -258 lines changed


functorch/_src/eager_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -1325,7 +1325,7 @@ def _register_python_decomposition_vmap(decomp):
         raise RuntimeError(f"could not find decomposition for {decomp}")
 
 
-_register_jit_decomposition(torch.ops.aten.trace.default)
+_register_jit_decomposition(torch.ops.aten.trace.default, use_python=True)
 _register_jit_decomposition(torch.ops.aten.nll_loss_backward.default)
 _register_jit_decomposition(torch.ops.aten.nll_loss2d_backward.default)
 _register_jit_decomposition(torch.ops.aten._log_softmax_backward_data.default)
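For context, `aten.trace` of a 2-D tensor is just the sum of its main diagonal, which is why it can be covered by a decomposition rather than a dedicated rule; the `use_python=True` flag presumably selects the Python-level decomposition when registering it. The standalone function below is purely illustrative and not part of this commit:

```python
import torch

def trace_decomposition(x: torch.Tensor) -> torch.Tensor:
    # Illustrative only: trace(x) == sum of the main diagonal, expressed
    # with ops (diagonal, sum) that tracing/vmap already understand.
    return torch.diagonal(x).sum()

x = torch.randn(4, 4)
assert torch.allclose(trace_decomposition(x), torch.trace(x))
```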

functorch/_src/python_key.py

Lines changed: 3 additions & 203 deletions
@@ -3,207 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import functools
-from typing import Any, Dict, Optional, Tuple, Callable, Union
-import torch
-from torch._C import _disabled_torch_function_impl
-import torch.utils._pytree as pytree
-from torch.fx import Tracer, GraphModule
-import torch.fx as fx
-from torch.fx.passes.shape_prop import _extract_tensor_metadata
-from contextlib import contextmanager
+__all__ = ["make_fx", "ProxyTensor", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
+from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose
 
-aten = torch.ops.aten
-
-CURRENT_DECOMPOSITION_TABLE = {}
-
-
-@contextmanager
-def no_dispatch():
-    guard = torch._C._DisableTorchDispatch()
-    try:
-        yield
-    finally:
-        del guard
-
-
-@contextmanager
-def pythonkey_decompose(decomposition_table):
-    global CURRENT_DECOMPOSITION_TABLE
-    CURRENT_DECOMPOSITION_TABLE = decomposition_table
-    try:
-        yield CURRENT_DECOMPOSITION_TABLE
-    finally:
-        CURRENT_DECOMPOSITION_TABLE = {}
-
-
-class PythonTensor(torch.Tensor):
-    elem: torch.Tensor
-
-    __slots__ = ['elem', 'proxy']
-
-    @staticmethod
-    def __new__(cls, elem, proxy):
-        # Wrapping something in PythonTensor implicitly detaches
-        # gradients. If something required grad, we will collect it as if it
-        # were a leaf. A consequence of detaching in this way is you
-        # need to maintain a parameter cache when translating tensors
-        # into PythonTensor, so you don't create multiple copies of
-        # a gradient (they are aliased, but they would count as independent
-        # leaves). An alternate strategy would be to avoid implicitly
-        # detaching and instead "catch" gradients as they exit the
-        # PythonTensor boundary.
-        # assert not elem.requires_grad or not torch.is_grad_enabled()
-
-        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-        r.proxy = proxy
-        if elem.is_sparse:
-            proxy.node.meta['tensor_meta'] = {}
-        else:
-            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(r)
-        return r
-
-    def __repr__(self):
-        with no_dispatch():
-            return f"PythonTensor({self.as_subclass(torch.Tensor)})"
-
-    __torch_function__ = _disabled_torch_function_impl
-
-    def __deepcopy__(self, memo):
-        return self.clone()
-
-    @classmethod
-    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
-        func = func_overload.overloadpacket
-        if func_overload in CURRENT_DECOMPOSITION_TABLE:
-            return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
-        # Commenting this out for now since it causes some spurious failures (such as error checking)
-        # if func == aten._local_scalar_dense:
-        #     raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
-        #                        "It's likely that this is caused by data-dependent control flow or similar.")
-
-        def unwrap_proxy(e):
-            return e.proxy if isinstance(e, PythonTensor) else e
-
-        proxy_args = pytree.tree_map(unwrap_proxy, args)
-        proxy_kwargs = pytree.tree_map(unwrap_proxy, kwargs)
-
-        proxy_out = func(*proxy_args, **proxy_kwargs)
-
-        # Kind of a hacky way to test if an op is in-place or not
-        if func.__name__[-1] == "_" and func.__name__[0] != "_":
-            args[0].proxy = proxy_out
-            proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
-
-        with no_dispatch():
-            real_out = func_overload(*args, **kwargs)
-
-        def wrap_with_proxy(e, proxy):
-            # Some ops (like native_batch_norm_backward) return undefined tensors that get
-            # converted into None in python.
-            # As the function signature expects tensors, if we directly return these None
-            # tensors back to C++, we'll error.
-            if e is None:
-                e = torch.empty(())
-            if type(e) == torch.Tensor:
-                return PythonTensor(e, proxy)
-            else:
-                return e
-        if isinstance(real_out, tuple):
-            return tuple(wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out))
-        elif isinstance(real_out, list):
-            return [wrap_with_proxy(e, proxy_out[idx]) for idx, e in enumerate(real_out)]
-        elif isinstance(real_out, torch.Tensor):
-            return wrap_with_proxy(real_out, proxy_out)
-        else:
-            return real_out
-
-
-class PythonKeyTracer(Tracer):
-    def __init__(self):
-        super().__init__()
-
-    def call_module(
-        self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]
-    ) -> Any:
-        return forward(*args, **kwargs)
-
-    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
-        if isinstance(attr_val, torch.nn.Parameter):
-            for n, p in self.root.named_parameters():
-                if attr_val is p:
-                    if n not in parameter_proxy_cache:
-                        proxy = self.create_proxy('get_attr', n, (), {})
-                        parameter_proxy_cache[n] = PythonTensor(attr_val, proxy)
-                    return parameter_proxy_cache[n]
-            return attr_val
-        return attr_val
-
-    # We need to do this so that parameters entering the `make_fx` context have
-    # a reference to them (and also have requires_grad set on them correctly
-    # I'm not actually sure if this is the right thing to do ...
-    def create_arg(self, a: Any):
-        if isinstance(a, torch.nn.Parameter):
-            for n, p in self.root.named_parameters():
-                if a is p:
-                    return self.create_node('get_attr', n, (), {})
-            qualname: Optional[str] = None
-
-            if not qualname:
-                i = 0
-                while True:
-                    qualname = f'_param_constant{i}'
-                    if not hasattr(self.root, qualname):
-                        break
-                    i += 1
-                setattr(self.root, qualname, a)
-
-            return self.create_node('get_attr', qualname, (), {})
-        return super().create_arg(a)
-
-
-def pythonkey_trace(
-    root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None
-) -> GraphModule:
-    tracer = PythonKeyTracer()
-    graph = tracer.trace(root, concrete_args)
-    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
-    return GraphModule(tracer.root, graph, name)
-
-
-def wrap_key(f, inps):
-    flat_inps, inp_spec = pytree.tree_flatten(inps)
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        flat_args, args_spec = pytree.tree_flatten(args)
-        assert(len(flat_args) == len(flat_inps))
-        for idx, arg in enumerate(flat_args):
-            if isinstance(flat_inps[idx], torch.Tensor):
-                flat_args[idx] = PythonTensor(flat_inps[idx], arg)
-            else:
-                flat_args[idx] = flat_inps[idx]
-
-        tree_args = pytree.tree_unflatten(flat_args, args_spec)
-        out = f(*tree_args)
-        flat_outs, out_spec = pytree.tree_flatten(out)
-        for idx in range(len(flat_outs)):
-            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], PythonTensor):
-                flat_outs[idx] = flat_outs[idx].proxy
-        return pytree.tree_unflatten(flat_outs, out_spec)
-
-    return wrapped
-
-
-def make_fx(f, decomposition_table=None):
-    if decomposition_table is None:
-        decomposition_table = {}
-
-    @functools.wraps(f)
-    def wrapped(*args):
-        phs = pytree.tree_map(lambda x: fx.PH, args)
-        with pythonkey_decompose(decomposition_table):
-            t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs))
-        return t
-
-    return wrapped
+pythonkey_decompose = decompose
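With this change, `functorch._src.python_key` becomes a thin re-export of the tracer that now lives in PyTorch core. A minimal usage sketch, assuming a build where `torch.fx.experimental.proxy_tensor` exposes `make_fx` as the new import above requires:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx  # same object functorch re-exports

def f(x):
    return torch.sin(x) * torch.cos(x)

# Tracing happens at the dispatcher level: calling the wrapped function with
# example inputs returns a torch.fx.GraphModule recording the aten ops hit.
gm = make_fx(f)(torch.randn(3))
print(gm.graph)            # nodes for aten.sin, aten.cos, aten.mul
print(gm(torch.randn(3)))  # the GraphModule is itself callable
```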

test/test_pythonkey.py

Lines changed: 1 addition & 54 deletions
@@ -191,59 +191,6 @@ def f(x):
         self.assertEqual(grads, grads2)
 
 
-make_fx_failures = {
-    xfail('allclose'),
-    xfail('nn.functional.dropout'),
-    xfail('linalg.eigvals'),
-    xfail('nn.functional.max_pool1d', device_type='cpu'),  # precision problems?
-    xfail('randn_like'),  # randomness
-    xfail('rand_like'),  # randomness
-    xfail('randint_like'),  # randomness
-    skip('new_empty'),  # nondeterministic
-    skip('empty_like'),  # nondeterministic
-    skip('linalg.lstsq', 'grad_oriented'),  # flaky
-    xfail('normal', '', device_type='cpu'),
-    xfail('normal', 'number_mean', device_type='cpu'),
-    xfail('multinomial', device_type='cpu'),
-    xfail('nn.functional.feature_alpha_dropout', 'with_train', device_type='cpu'),
-    xfail('bernoulli', device_type='cpu'),
-    xfail('nn.functional.dropout2d', device_type='cpu'),
-    skip('nn.functional.max_unpool1d', '', device_type='cpu'),  # flaky
-    skip('nn.functional.max_unpool2d', '', device_type='cpu'),  # flaky
-    skip('nn.functional.max_unpool3d', '', device_type='cpu'),  # flaky
-    skip('linalg.lstsq'),  # flaky, probably just a precision issue
-    xfail('histogram'),
-    xfail('scatter')
-}
-
-
-class TestPythonKeyOperatorsOpInfo(TestCase):
-    @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
-    @skipOps('TestPythonKeyOperatorsOpInfo', 'test_make_fx_exhaustive', make_fx_failures
-             )
-    def test_make_fx_exhaustive(self, device, dtype, op):
-
-        def f(args, kwargs):
-            return op.op(*args, **kwargs)
-        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
-        new_f = None
-        for sample_input in sample_inputs_itr:
-            args = [sample_input.input] + list(sample_input.args)
-            kwargs = sample_input.kwargs
-
-            new_f = make_fx(f)(args, kwargs)
-            for arg in args:
-                if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
-                    arg.uniform_(0, 1)
-            try:
-                old_out = f(args, kwargs)
-            except Exception:
-                continue
-            new_out = new_f(args, kwargs)
-            self.assertEqual(new_out, old_out)
-            pass
-
-
 def _outs_and_grads(fn, inps):
     outs = fn(*inps)
     for out in pytree.tree_flatten(outs)[0]:
@@ -375,6 +322,7 @@ class TestEagerFusionOpInfo(TestCase):
         xfail('diag_embed'),
         xfail('linalg.householder_product'),
         xfail('logit'),
+        xfail('logdet'),
         xfail('matrix_exp'),
         xfail('trapezoid'),
         xfail('trapz'),
@@ -604,7 +552,6 @@ def forward(self, x, y):
     globals(),
     only_for=only_for,
 )
-instantiate_device_type_tests(TestPythonKeyOperatorsOpInfo, globals(), only_for=only_for)
 instantiate_device_type_tests(TestEagerFusionOpInfo, globals(), only_for=only_for)
 
 
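The exhaustive `TestPythonKeyOperatorsOpInfo` suite is dropped here, presumably because its coverage travels with the tracer into core. The property it checked is easy to reproduce by hand for a single op: trace with `make_fx`, then compare the graph's output against eager. A sketch, not taken from the repository's test suite:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.relu(x)

x = torch.randn(8)
gm = make_fx(f)(x)                       # trace once on an example input
torch.testing.assert_close(gm(x), f(x))  # traced graph should match eager
```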
