
Commit d764cbe

ydwu4 authored and facebook-github-bot committed
fix spec_prop_pass
Summary: fix spec_prop_pass by just consulting node.meta["val"]

Differential Revision: D66996237
1 parent 270271b commit d764cbe
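
The gist of the change: rather than recomputing specs per node kind inside an ExportPass subclass, the pass now walks the graph once and derives each node's meta["spec"] directly from its already-populated meta["val"]. A minimal self-contained sketch of that idea (the make_spec stand-in and the propagate_specs name are hypothetical, not the real ExecuTorch helpers):

import torch
import torch.fx
from torch.utils import _pytree as pytree


def make_spec(x):
    # Hypothetical stand-in for ExecuTorch's make_spec: record shape and dtype.
    return (tuple(x.shape), x.dtype) if isinstance(x, torch.Tensor) else None


def propagate_specs(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Mirrors the rewritten pass: derive meta["spec"] from meta["val"] for
    # every node except get_attr, in the outer module and nested submodules.
    for module in gm.modules():
        if isinstance(module, torch.fx.GraphModule):
            for node in module.graph.nodes:
                if node.op == "get_attr" or "val" not in node.meta:
                    continue
                node.meta["spec"] = pytree.tree_map(make_spec, node.meta["val"])
    return gm

In the real pass, ExportPass()(gm) runs first so that every node's meta["val"] is refreshed before the specs are derived; the "val" guard above is an extra assumption for the sketch.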

2 files changed (+42, -109 lines)


exir/passes/spec_prop_pass.py

Lines changed: 38 additions & 105 deletions
@@ -15,6 +15,7 @@
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx.node import Node
 from torch.utils import _pytree as pytree
+from torch.fx.passes.infra.pass_base import PassResult
 
 
 # pyre-ignore
@@ -45,109 +46,41 @@ def _is_mutable_buffer(
             return True
     return False
 
-
-class SpecPropPass(ExportPass):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def on_attr(self, attr: ProxyValue) -> None:
-        attr.node.meta["spec"] = pytree.tree_map_only(
-            torch.Tensor,
-            make_spec,
-            attr.data,
-        )
-
-    def update_placeholder_tensor_specs(
-        self,
-        exported_program: torch.export.ExportedProgram,
-        graph_module: torch.fx.GraphModule,
-    ) -> None:
-        """
-        Update the tensor specs for all placeholder nodes such that
-        placeholders that are parameters are marked as constant.
-        """
-        for node in graph_module.graph.nodes:
-            if node.op != "placeholder":
-                continue
-            if "spec" not in node.meta:
-                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
-            spec = node.meta["spec"]
-            if isinstance(node.target, str) and (
-                node.target in exported_program.graph_signature.inputs_to_parameters
-                or (
-                    node.target in exported_program.graph_signature.inputs_to_buffers
-                    and not _is_mutable_buffer(node, exported_program.graph_signature)
-                )
-                or node.target
-                in exported_program.graph_signature.inputs_to_lifted_tensor_constants
-            ):
-                spec.const = True
-
-    # pyre-ignore
-    def placeholder(self, name: str, arg, meta):
-        meta["spec"] = make_spec(arg)
-        return super().placeholder(name, arg, meta)
-
-    # pyre-ignore
-    def call_operator(self, op, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
-        return super().call_operator(op, args, kwargs, meta)
-
-    # pyre-ignore
-    def call_getitem(self, value, key: int, meta):
-        meta["spec"] = value.node.meta["spec"][key]
-        return super().call_getitem(value, key, meta)
-
-    # pyre-ignore
-    def call_cond(self, pred, true_fn, false_fn, inputs, meta):
-        # true_fn/false_fn return tensors of the same shape, so we can pick
-        # either one here.
-        *_, true_out_node = true_fn.graph.nodes
-        meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
-        return super().call_cond(pred, true_fn, false_fn, inputs, meta)
-
-    def call_map(
-        self,
-        f: torch.fx.GraphModule,
-        mapped_args: List[ProxyValue],
-        operands: List[ProxyValue],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
-        *_, body_out_node = f.graph.nodes
-        body_out_node_fake_tensor = body_out_node.meta["val"]
-        map_fake_tensor = pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: x.new_empty(mapped_dim_size, *x.shape),
-            body_out_node_fake_tensor,
-        )
-        meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
-        return super().call_map(f, mapped_args, operands, meta)
-
-    # pyre-ignore
-    def call_delegate(self, lowered_module, args, kwargs, meta):
-        args_data, kwargs_data = pytree.tree_map_only(
-            ProxyValue, lambda x: x.data, (args, kwargs)
-        )
-        # If spec is missing, re-genenrate it with args data
-        if "spec" not in meta:
-            meta["spec"] = pytree.tree_map(
-                make_spec,
-                executorch_call_delegate(lowered_module, *args_data),
+def SpecPropPass(gm: torch.fx.GraphModule) -> PassResult:
+    # Update all the meta["val"]
+    pass_result = ExportPass()(gm)
+    assert pass_result is not None
+    gm = pass_result.graph_module
+    # set node.meta["spec"] based on meta["val"]
+    for module in gm.modules():
+        if isinstance(module, torch.fx.GraphModule):
+            for node in module.graph.nodes:
+                if node.op == "get_attr":
+                    continue
+                node.meta["spec"] = pytree.tree_map(lambda meta_val: make_spec(meta_val), node.meta["val"])
+    return pass_result
+
+def update_placeholder_tensor_specs(
+    exported_program: torch.export.ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+) -> None:
+    """
+    Update the tensor specs for all placeholder nodes such that
+    placeholders that are parameters are marked as constant.
+    """
+    for node in graph_module.graph.nodes:
+        if node.op != "placeholder":
+            continue
+        if "spec" not in node.meta:
+            raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
+        spec = node.meta["spec"]
+        if isinstance(node.target, str) and (
+            node.target in exported_program.graph_signature.inputs_to_parameters
+            or (
+                node.target in exported_program.graph_signature.inputs_to_buffers
+                and not _is_mutable_buffer(node, exported_program.graph_signature)
+            )
             )
-        return super().call_delegate(lowered_module, args, kwargs, meta)
-
-    # pyre-ignore
-    def output(self, results, meta):
-        # pyre-ignore
-        def get_spec(x):
-            if isinstance(x, ProxyValue):
-                return x.node.meta["spec"]
-            else:
-                return make_spec(x)
-
-        meta["spec"] = pytree.tree_map(get_spec, results)
-        return super().output(results, meta)
+            or node.target
+            in exported_program.graph_signature.inputs_to_lifted_tensor_constants
+        ):
+            spec.const = True
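
The const-marking logic above keys off the export graph signature: a placeholder whose target maps to a parameter, a non-mutable buffer, or a lifted tensor constant gets spec.const = True so that memory planning can skip it. A small, hedged illustration of the signature lookup using only public torch.export APIs (the Toy module is hypothetical):

import torch
from torch.export import export


class Toy(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


ep = export(Toy(), (torch.randn(2, 4),))
# Placeholders whose targets appear in inputs_to_parameters are lifted
# parameters; a spec-prop follow-up would mark exactly these as const.
for node in ep.graph.nodes:
    if node.op == "placeholder":
        is_param = node.target in ep.graph_signature.inputs_to_parameters
        print(node.target, "const (parameter)" if is_param else "runtime input")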

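Because SpecPropPass is now a plain function of shape (GraphModule) -> PassResult, it drops into the same pass list as other callables and is invoked uniformly, which is what the _program.py changes below adapt to. A sketch of that calling convention (toy_pass and run_passes are hypothetical names):

import torch.fx
from torch.fx.passes.infra.pass_base import PassResult


def toy_pass(gm: torch.fx.GraphModule) -> PassResult:
    # Hypothetical function-style pass: leave the graph unchanged and say so.
    return PassResult(gm, False)


def run_passes(gm: torch.fx.GraphModule, passes) -> torch.fx.GraphModule:
    # Every entry, class instance or plain function, is called the same way
    # and must return a PassResult wrapping the (possibly new) graph module.
    for p in passes:
        result = p(gm)
        assert result is not None
        gm = result.graph_module
    return gm
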
exir/program/_program.py

Lines changed: 4 additions & 4 deletions
@@ -54,7 +54,7 @@
 from executorch.exir.passes.replace_view_copy_with_view_pass import (
     ReplaceViewCopyWithViewPass,
 )
-from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from executorch.exir.passes.spec_prop_pass import SpecPropPass, update_placeholder_tensor_specs
 from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
 from executorch.exir.print_program import pretty_print, print_program
 from executorch.exir.schema import Program
@@ -734,7 +734,7 @@ def edge_to_executorch_passes(
     """
     passes: List[PassType] = [
         *config.passes,
-        SpecPropPass(),
+        SpecPropPass,
         # ExecuTorch backend ops are unable to handle unbacked symints. So after
         # this pass, passes cannot be Interpreter-based, because it will fail if
         # there exists an unbacked symint operation.
@@ -1390,7 +1390,7 @@ def to_executorch(
                 new_gm_res = p(new_gm)
                 assert new_gm_res is not None
                 new_gm = new_gm_res.graph_module
-                if isinstance(p, SpecPropPass):
+                if p is SpecPropPass:
                     # Note that this is a hacky way to get around the fact that
                     # placeholder nodes corresponding to the parameters of the graph module
                     # shall not participate in memory planning. It increases runtime memory
@@ -1401,7 +1401,7 @@
                     # Working with GraphModule does not provide all the information contained
                     # in the ExportedProgram
                     # TODO(who?)
-                    p.update_placeholder_tensor_specs(program, new_gm)
+                    update_placeholder_tensor_specs(program, new_gm)
 
             # Extract constants if the config says too.
             if config.external_constants:
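
One side effect of the rewrite is visible in the last two hunks: isinstance(p, SpecPropPass) only makes sense while SpecPropPass is a class, so the call site switches to an identity comparison against the function object, and update_placeholder_tensor_specs is called as a free function. A tiny illustration of why (spec_prop_pass here is a hypothetical stand-in):

def spec_prop_pass(gm):
    # Hypothetical function-style pass standing in for SpecPropPass.
    return gm


passes = [spec_prop_pass]
for p in passes:
    # isinstance(p, spec_prop_pass) would raise TypeError, since
    # spec_prop_pass is a function rather than a type; compare identity.
    if p is spec_prop_pass:
        print("run placeholder-spec follow-up here")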
