Skip to content

Commit f19f9d9

Browse files
authored
Move exir.delegate to PyTorch core to enforce no out-of-tree HOPs
Differential Revision: D60674615 Pull Request resolved: #4521
1 parent 9a4f32d commit f19f9d9

File tree

3 files changed

+163
-133
lines changed

3 files changed

+163
-133
lines changed

exir/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ python_library(
108108
name = "delegate",
109109
srcs = [
110110
"delegate.py",
111+
"delegate.pyi",
111112
],
112113
deps = [
113114
"//caffe2:torch",

exir/delegate.py

Lines changed: 141 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -8,156 +8,164 @@
88

99
from __future__ import annotations
1010

11-
from typing import Any, cast
11+
try: # noqa: C901
12+
from torch._higher_order_ops.executorch_call_delegate import (
13+
executorch_call_delegate as executorch_call_delegate,
14+
get_lowered_module_name as get_lowered_module_name,
15+
is_lowered_module as is_lowered_module,
16+
)
17+
18+
except ImportError:
1219

13-
import torch
14-
import torch.utils._pytree as pytree
15-
from torch._ops import HigherOrderOperator
16-
from torch._subclasses.fake_tensor import FakeTensorMode
17-
from torch.fx.experimental.proxy_tensor import (
18-
disable_proxy_modes_tracing,
19-
get_proxy_slot,
20-
ProxyTorchDispatchMode,
21-
track_tensor_tree,
22-
)
23-
from torch.utils._pytree import tree_flatten
20+
# TODO: Delete this code once pytorch pin advances
2421

22+
from typing import Any, cast
2523

26-
executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
27-
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
28-
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
29-
executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
30-
executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
24+
import torch
25+
import torch.utils._pytree as pytree
26+
from torch._ops import HigherOrderOperator
27+
from torch._subclasses.fake_tensor import FakeTensorMode
28+
from torch.fx.experimental.proxy_tensor import (
29+
disable_proxy_modes_tracing,
30+
get_proxy_slot,
31+
ProxyTorchDispatchMode,
32+
track_tensor_tree,
33+
)
34+
from torch.utils._pytree import tree_flatten
3135

32-
LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
36+
executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
37+
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
38+
executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
39+
executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
40+
executorch_call_delegate.fallthrough(torch._C.DispatchKey.AutocastCPU)
3341

42+
LOWERED_BACKEND_MODULE_TYPE = "LoweredBackendModule"
3443

35-
# pyre-ignore
36-
def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
3744
# pyre-ignore
38-
def _unwrap_proxy(e):
39-
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
40-
return e
41-
return get_proxy_slot(
42-
cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy
45+
def trace_call_delegate(proxy_mode, func_overload, lowered_module, *args):
46+
# pyre-ignore
47+
def _unwrap_proxy(e):
48+
if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)):
49+
return e
50+
return get_proxy_slot(
51+
cast(torch.Tensor, e), proxy_mode.tracer, e, lambda e: e.proxy
52+
)
53+
54+
if not is_lowered_module(lowered_module):
55+
raise ValueError(
56+
"executorch_call_delegate()'s first argument must be a LoweredBackendModule"
57+
)
58+
59+
with disable_proxy_modes_tracing():
60+
out = call_delegate_cpu(lowered_module, *args)
61+
62+
get_lowered_module_name(proxy_mode.tracer.root, lowered_module)
63+
64+
node_args = (lowered_module, *args)
65+
proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
66+
out_proxy = proxy_mode.tracer.create_proxy(
67+
"call_function",
68+
func_overload,
69+
proxy_args,
70+
{},
71+
name="executorch_call_delegate",
4372
)
44-
45-
if not is_lowered_module(lowered_module):
46-
raise ValueError(
47-
"executorch_call_delegate()'s first argument must be a LoweredBackendModule"
73+
return track_tensor_tree(
74+
out, out_proxy, constant=None, tracer=proxy_mode.tracer
4875
)
4976

50-
with disable_proxy_modes_tracing():
51-
out = call_delegate_cpu(lowered_module, *args)
52-
53-
get_lowered_module_name(proxy_mode.tracer.root, lowered_module)
54-
55-
node_args = (lowered_module, *args)
56-
proxy_args = pytree.tree_map(_unwrap_proxy, node_args)
57-
out_proxy = proxy_mode.tracer.create_proxy(
58-
"call_function", func_overload, proxy_args, {}, name="executorch_call_delegate"
59-
)
60-
return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)
61-
62-
63-
@executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
64-
# pyre-ignore
65-
def call_delegate_cpu(lowered_module, *args):
66-
# FX creates this immutable_dict/list concept. Get rid of this.
67-
map_types = {
68-
torch.fx.immutable_collections.immutable_dict: dict,
69-
torch.fx.immutable_collections.immutable_list: list,
70-
}
71-
new_args = pytree.tree_map_only(
72-
tuple(map_types.keys()),
73-
lambda a: map_types[type(a)](a),
74-
args,
75-
lambda a: isinstance(a, tuple(map_types.keys())),
76-
)
77-
return lowered_module.original_module.module()(*new_args)
77+
@executorch_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)
78+
# pyre-ignore
79+
def call_delegate_cpu(lowered_module, *args):
80+
# FX creates this immutable_dict/list concept. Get rid of this.
81+
map_types = {
82+
torch.fx.immutable_collections.immutable_dict: dict,
83+
torch.fx.immutable_collections.immutable_list: list,
84+
}
85+
new_args = pytree.tree_map_only(
86+
tuple(map_types.keys()),
87+
lambda a: map_types[type(a)](a),
88+
args,
89+
lambda a: isinstance(a, tuple(map_types.keys())),
90+
)
91+
return lowered_module.original_module.module()(*new_args)
7892

93+
@executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
94+
# pyre-ignore
95+
def call_delegate_autograd(lowered_module, *args):
96+
# TODO: support autograd
97+
flat_operands, _ = tree_flatten([lowered_module, *args])
98+
requires_grad = any(
99+
f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
100+
)
79101

80-
@executorch_call_delegate.py_impl(torch._C.DispatchKey.Autograd)
81-
# pyre-ignore
82-
def call_delegate_autograd(lowered_module, *args):
83-
# TODO: support autograd
84-
flat_operands, _ = tree_flatten([lowered_module, *args])
85-
requires_grad = any(
86-
f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
87-
)
102+
with torch._C._ExcludeDispatchKeyGuard(
103+
torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
104+
):
105+
res = executorch_call_delegate(lowered_module, *args)
88106

89-
with torch._C._ExcludeDispatchKeyGuard(
90-
torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
91-
):
92-
res = executorch_call_delegate(lowered_module, *args)
107+
if requires_grad:
108+
# Create aliases of the output that has requires_grad=True. We need
109+
# at least one of the inputs to err_fn to require grad so that the
110+
# output will have a grad_fn.
93111

94-
if requires_grad:
95-
# Create aliases of the output that has requires_grad=True. We need
96-
# at least one of the inputs to err_fn to require grad so that the
97-
# output will have a grad_fn.
112+
# pyre-ignore
113+
def fake_requires_grad(var):
114+
if var is not None:
115+
var = var.detach()
116+
if torch.is_floating_point(var) or torch.is_complex(var):
117+
var.requires_grad = True
118+
return var
98119

99-
# pyre-ignore
100-
def fake_requires_grad(var):
101-
if var is not None:
102-
var = var.detach()
103-
if torch.is_floating_point(var) or torch.is_complex(var):
104-
var.requires_grad = True
105-
return var
120+
return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res)
106121

107-
return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res)
122+
return res
108123

124+
@executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
125+
# pyre-ignore
126+
def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
127+
res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
109128
return res
110129

130+
@executorch_call_delegate.py_impl(FakeTensorMode)
131+
# pyre-ignore
132+
def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
133+
with mode:
134+
return call_delegate_cpu(lowered_module, *args)
111135

112-
@executorch_call_delegate.py_impl(ProxyTorchDispatchMode)
113-
# pyre-ignore
114-
def call_delegate_proxy_torch_dispatch_mode(mode, lowered_module, *args):
115-
res = trace_call_delegate(mode, executorch_call_delegate, lowered_module, *args)
116-
return res
117-
118-
119-
@executorch_call_delegate.py_impl(FakeTensorMode)
120-
# pyre-ignore
121-
def call_delegate_fake_tensor_mode(mode, lowered_module, *args):
122-
with mode:
123-
return call_delegate_cpu(lowered_module, *args)
124-
125-
126-
@executorch_call_delegate.py_functionalize_impl
127-
# pyre-ignore
128-
def call_delegate_functionalize(ctx, lowered_module, *args):
129-
unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
130-
with ctx.redispatch_to_next():
131-
res = executorch_call_delegate(lowered_module, *unwrapped_args)
132-
return ctx.wrap_tensors(res)
133-
134-
135-
# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
136-
def is_lowered_module(obj: Any) -> bool:
137-
"""
138-
This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import.
139-
"""
140-
return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE
141-
142-
143-
def get_lowered_module_name(
144-
root: torch.nn.Module,
145-
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
146-
lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa
147-
) -> str:
148-
"""
149-
Adds the given lowered_module into the given root module and returns the
150-
name of the module added.
151-
"""
152-
# Find a qualifying name for the lowered submodule
153-
qualname = None
154-
i = 0
155-
while True:
156-
qualname = f"lowered_module_{i}"
157-
if not hasattr(root, qualname):
158-
break
159-
i += 1
160-
assert qualname is not None
161-
162-
root.add_module(qualname, lowered_module)
163-
return qualname
136+
@executorch_call_delegate.py_functionalize_impl
137+
# pyre-ignore
138+
def call_delegate_functionalize(ctx, lowered_module, *args):
139+
unwrapped_args = tuple(ctx.unwrap_tensors(arg) for arg in args)
140+
with ctx.redispatch_to_next():
141+
res = executorch_call_delegate(lowered_module, *unwrapped_args)
142+
return ctx.wrap_tensors(res)
143+
144+
# pyre-ignore: Missing parameter annotation [2]: Parameter `obj` must have a type other than `Any`.Pyre
145+
def is_lowered_module(obj: Any) -> bool:
146+
"""
147+
This function is added to avoid using isinstance(obj, LoweredBackendModule) as it will import LoweredBackendModule, which may cause a circular import.
148+
"""
149+
return type(obj).__name__ == LOWERED_BACKEND_MODULE_TYPE
150+
151+
def get_lowered_module_name(
152+
root: torch.nn.Module,
153+
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
154+
lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa
155+
) -> str:
156+
"""
157+
Adds the given lowered_module into the given root module and returns the
158+
name of the module added.
159+
"""
160+
# Find a qualifying name for the lowered submodule
161+
qualname = None
162+
i = 0
163+
while True:
164+
qualname = f"lowered_module_{i}"
165+
if not hasattr(root, qualname):
166+
break
167+
i += 1
168+
assert qualname is not None
169+
170+
root.add_module(qualname, lowered_module)
171+
return qualname

exir/delegate.pyi

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# TODO: delete this when pytorch pin advances
8+
9+
from typing import Any
10+
11+
import torch
12+
from torch._ops import HigherOrderOperator
13+
14+
executorch_call_delegate: HigherOrderOperator
15+
16+
def is_lowered_module(obj: Any) -> bool: ...
17+
def get_lowered_module_name(
18+
root: torch.nn.Module,
19+
# pyre-ignore: Undefined or invalid type [11]: Annotation `LoweredBackendModule` is not defined as a type.
20+
lowered_module: LOWERED_BACKEND_MODULE_TYPE, # noqa
21+
) -> str: ...

0 commit comments

Comments
 (0)