8 | 8 |
9 | 9 | from __future__ import annotations |
10 | 10 |
11 | | -from typing import Any, cast |
| 11 | +try: # noqa: C901 |
| 12 | + from torch._higher_order_ops.executorch_call_delegate import ( |
| 13 | + executorch_call_delegate as executorch_call_delegate, |
| 14 | + get_lowered_module_name as get_lowered_module_name, |
| 15 | + is_lowered_module as is_lowered_module, |
| 16 | + ) |
| 17 | + |
| 18 | +except ImportError: |
12 | 19 |
13 | | -import torch |
14 | | -import torch.utils._pytree as pytree |
15 | | -from torch._ops import HigherOrderOperator |
16 | | -from torch._subclasses.fake_tensor import FakeTensorMode |
17 | | -from torch.fx.experimental.proxy_tensor import ( |
18 | | - disable_proxy_modes_tracing, |
19 | | - get_proxy_slot, |
20 | | - ProxyTorchDispatchMode, |
21 | | - track_tensor_tree, |
22 | | -) |
23 | | -from torch.utils._pytree import tree_flatten |
| 20 | + # TODO: Delete this code once pytorch pin advances |
24 | 21 |
| 22 | + from typing import Any, cast |
25 | 23 |
26 | | -executorch_call_delegate = HigherOrderOperator("executorch_call_delegate") |
27 | | -executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) |
28 | | -executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) |
29 | | -executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) |
30 | | -executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) |
| 24 | + import torch |
| 25 | + import torch.utils._pytree as pytree |
| 26 | + from torch._ops import HigherOrderOperator |
| 27 | + from torch._subclasses.fake_tensor import FakeTensorMode |
| 28 | + from torch.fx.experimental.proxy_tensor import ( |
| 29 | + disable_proxy_modes_tracing, |
| 30 | + get_proxy_slot, |
| 31 | + ProxyTorchDispatchMode, |
| 32 | + track_tensor_tree, |
| 33 | + ) |
| 34 | + from torch.utils._pytree import tree_flatten |
31 | 35 |
32 | | -LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" |
| 36 | + executorch_call_delegate = HigherOrderOperator("executorch_call_delegate") |
| 37 | + executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher) |
| 38 | + executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot) |
| 39 | + executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView) |
| 40 | + executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU) |
33 | 41 |
| 42 | + LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule" |
34 | 43 |
35 | | -# pyre-ignore |
36 | | -def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): |
37 | 44 | # pyre-ignore |
38 | | - def _unwrap_proxy(e): |
39 | | - if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): |
40 | | - return e |
41 | | - return get_proxy_slot( |
42 | | - cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy |
| 45 | + def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args): |
| 46 | + # pyre-ignore |
| 47 | + def _unwrap_proxy(e): |
| 48 | + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): |
| 49 | + return e |
| 50 | + return get_proxy_slot( |
| 51 | + cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy |
| 52 | + ) |
| 53 | + |
| 54 | + if not is_lowered_module(lowered_module): |
| 55 | + raise ValueError( |
| 56 | + "executorch_call_delegate()'s first argument must be a LoweredBackendModule" |
| 57 | + ) |
| 58 | + |
| 59 | + with disable_proxy_modes_tracing(): |
| 60 | + out = call_delegate_cpu(lowered_module, *args) |
| 61 | + |
| 62 | + get_lowered_module_name(proxy_mode.tracer.root, lowered_module) |
| 63 | + |
| 64 | + node_args = (lowered_module, *args) |
| 65 | + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) |
| 66 | + out_proxy = proxy_mode.tracer.create_proxy( |
| 67 | + "call_function", |
| 68 | + func_overload, |
| 69 | + proxy_args, |
| 70 | + {}, |
| 71 | + name="executorch_call_delegate", |
43 | 72 | ) |
44 | | - |
45 | | - if not is_lowered_module(lowered_module): |
46 | | - raise ValueError( |
47 | | - "executorch_call_delegate()'s first argument must be a LoweredBackendModule" |
| 73 | + return track_tensor_tree( |
| 74 | + out, out_proxy, constant=None, tracer=proxy_mode.tracer |
48 | 75 | ) |
49 | 76 |
50 | | - with disable_proxy_modes_tracing(): |
51 | | - out = call_delegate_cpu(lowered_module, *args) |
52 | | - |
53 | | - get_lowered_module_name(proxy_mode.tracer.root, lowered_module) |
54 | | - |
55 | | - node_args = (lowered_module, *args) |
56 | | - proxy_args = pytree.tree_map(_unwrap_proxy, node_args) |
57 | | - out_proxy = proxy_mode.tracer.create_proxy( |
58 | | - "call_function", func_overload, proxy_args, {}, name="executorch_call_delegate" |
59 | | - ) |
60 | | - return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) |
61 | | - |
62 | | - |
63 | | -@executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) |
64 | | -# pyre-ignore |
65 | | -def call_delegate_cpu(lowered_module, *args): |
66 | | - # FX creates this immutable_dict/list concept. Get rid of this. |
67 | | - map_types = { |
68 | | - torch.fx.immutable_collections.immutable_dict: dict, |
69 | | - torch.fx.immutable_collections.immutable_list: list, |
70 | | - } |
71 | | - new_args = pytree.tree_map_only( |
72 | | - tuple(map_types.keys()), |
73 | | - lambda a: map_types[type(a)](a), |
74 | | - args, |
75 | | - lambda a: isinstance(a, tuple(map_types.keys())), |
76 | | - ) |
77 | | - return lowered_module.original_module.module()(*new_args) |
| 77 | + @executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) |
| 78 | + # pyre-ignore |
| 79 | + def call_delegate_cpu(lowered_module, *args): |
| 80 | + # FX creates this immutable_dict/list concept. Get rid of this. |
| 81 | + map_types = { |
| 82 | + torch.fx.immutable_collections.immutable_dict: dict, |
| 83 | + torch.fx.immutable_collections.immutable_list: list, |
| 84 | + } |
| 85 | + new_args = pytree.tree_map_only( |
| 86 | + tuple(map_types.keys()), |
| 87 | + lambda a: map_types[type(a)](a), |
| 88 | + args, |
| 89 | + lambda a: isinstance(a, tuple(map_types.keys())), |
| 90 | + ) |
| 91 | + return lowered_module.original_module.module()(*new_args) |
78 | 92 |
| 93 | + @executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd) |
| 94 | + # pyre-ignore |
| 95 | + def call_delegate_autograd(lowered_module, *args): |
| 96 | + # TODO: support autograd |
| 97 | + flat_operands, _ = tree_flatten([lowered_module, *args]) |
| 98 | + requires_grad = any( |
| 99 | + f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) |
| 100 | + ) |
79 | 101 |
80 | | -@executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd) |
81 | | -# pyre-ignore |
82 | | -def call_delegate_autograd(lowered_module, *args): |
83 | | - # TODO: support autograd |
84 | | - flat_operands, _ = tree_flatten([lowered_module, *args]) |
85 | | - requires_grad = any( |
86 | | - f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) |
87 | | - ) |
| 102 | + with torch._C._ExcludeDispatchKeyGuard( |
| 103 | + torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) |
| 104 | + ): |
| 105 | + res = executorch_call_delegate(lowered_module, *args) |
88 | 106 |
89 | | - with torch._C._ExcludeDispatchKeyGuard( |
90 | | - torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) |
91 | | - ): |
92 | | - res = executorch_call_delegate(lowered_module, *args) |
| 107 | + if requires_grad: |
| 108 | + # Create aliases of the output that has requires_grad=True. We need |
| 109 | + # at least one of the inputs to err_fn to require grad so that the |
| 110 | + # output will have a grad_fn. |
93 | 111 |
94 | | - if requires_grad: |
95 | | - # Create aliases of the output that has requires_grad=True. We need |
96 | | - # at least one of the inputs to err_fn to require grad so that the |
97 | | - # output will have a grad_fn. |
| 112 | + # pyre-ignore |
| 113 | + def fake_requires_grad(var): |
| 114 | + if var is not None: |
| 115 | + var = var.detach() |
| 116 | + if torch.is_floating_point(var) or torch.is_complex(var): |
| 117 | + var.requires_grad = True |
| 118 | + return var |
98 | 119 |
99 | | - # pyre-ignore |
100 | | - def fake_requires_grad(var): |
101 | | - if var is not None: |
102 | | - var = var.detach() |
103 | | - if torch.is_floating_point(var) or torch.is_complex(var): |
104 | | - var.requires_grad = True |
105 | | - return var |
| 120 | + return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) |
106 | 121 |
107 | | - return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) |
| 122 | + return res |
108 | 123 |
| 124 | + @executorch_call_delegate.py_impl(ProxyTorchDispatchMode) |
| 125 | + # pyre-ignore |
| 126 | + def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args): |
| 127 | + res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args) |
109 | 128 | return res |
110 | 129 |
| 130 | + @executorch_call_delegate.py_impl(FakeTensorMode) |
| 131 | + # pyre-ignore |
| 132 | + def call_delegate_fake_tensor_mode(mode, lowered_module, *args): |
| 133 | + with mode: |
| 134 | + return call_delegate_cpu(lowered_module, *args) |
111 | 135 |
112 | | -@executorch_call_delegate.py_impl(ProxyTorchDispatchMode) |
113 | | -# pyre-ignore |
114 | | -def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args): |
115 | | - res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args) |
116 | | - return res |
117 | | - |
118 | | - |
119 | | -@executorch_call_delegate.py_impl(FakeTensorMode) |
120 | | -# pyre-ignore |
121 | | -def call_delegate_fake_tensor_mode(mode, lowered_module, *args): |
122 | | - with mode: |
123 | | - return call_delegate_cpu(lowered_module, *args) |
124 | | - |
125 | | - |
126 | | -@executorch_call_delegate.py_functionalize_impl |
127 | | -# pyre-ignore |
128 | | -def call_delegate_functionalize(ctx, lowered_module, *args): |
129 | | - unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) |
130 | | - with ctx.redispatch_to_next(): |
131 | | - res = executorch_call_delegate(lowered_module, *unwrapped_args) |
132 | | - return ctx.wrap_tensors(res) |
133 | | - |
134 | | - |
135 | | -# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre |
136 | | -def is_lowered_module(obj: Any) -> bool: |
137 | | - """ |
138 | | - This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import. |
139 | | - """ |
140 | | - return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE |
141 | | - |
142 | | - |
143 | | -def get_lowered_module_name( |
144 | | - root: torch.nn.Module, |
145 | | - # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. |
146 | | - lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa |
147 | | -) -> str: |
148 | | - """ |
149 | | - Adds the given lowered_module into the given root module and returns the |
150 | | - name of the module added. |
151 | | - """ |
152 | | - # Find a qualifying name for the lowered submodule |
153 | | - qualname = None |
154 | | - i = 0 |
155 | | - while True: |
156 | | - qualname = f"lowered_module_{i}" |
157 | | - if not hasattr(root, qualname): |
158 | | - break |
159 | | - i += 1 |
160 | | - assert qualname is not None |
161 | | - |
162 | | - root.add_module(qualname, lowered_module) |
163 | | - return qualname |
| 136 | + @executorch_call_delegate.py_functionalize_impl |
| 137 | + # pyre-ignore |
| 138 | + def call_delegate_functionalize(ctx, lowered_module, *args): |
| 139 | + unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args) |
| 140 | + with ctx.redispatch_to_next(): |
| 141 | + res = executorch_call_delegate(lowered_module, *unwrapped_args) |
| 142 | + return ctx.wrap_tensors(res) |
| 143 | + |
| 144 | + # pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre |
| 145 | + def is_lowered_module(obj: Any) -> bool: |
| 146 | + """ |
| 147 | + This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import. |
| 148 | + """ |
| 149 | + return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE |
| 150 | + |
| 151 | + def get_lowered_module_name( |
| 152 | + root: torch.nn.Module, |
| 153 | + # pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type. |
| 154 | + lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa |
| 155 | + ) -> str: |
| 156 | + """ |
| 157 | + Adds the given lowered_module into the given root module and returns the |
| 158 | + name of the module added. |
| 159 | + """ |
| 160 | + # Find a qualifying name for the lowered submodule |
| 161 | + qualname = None |
| 162 | + i = 0 |
| 163 | + while True: |
| 164 | + qualname = f"lowered_module_{i}" |
| 165 | + if not hasattr(root, qualname): |
| 166 | + break |
| 167 | + i += 1 |
| 168 | + assert qualname is not None |
| 169 | + |
| 170 | + root.add_module(qualname, lowered_module) |
| 171 | + return qualname |