Skip to content

Commit 03bb3e4

Browse files
ydwu4facebook-github-bot
authored andcommitted
fix spec_prop_pass (#7974)
Summary: fix spec_prop_pass by just consulting node.meta["val"] Differential Revision: D66996237
1 parent 0733973 commit 03bb3e4

File tree

1 file changed

+17
-81
lines changed

1 file changed

+17
-81
lines changed

exir/passes/spec_prop_pass.py

Lines changed: 17 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.export.exported_program import ExportGraphSignature
1616
from torch.fx.node import Node
1717
from torch.utils import _pytree as pytree
18+
from torch.fx.passes.infra.pass_base import PassResult
1819

1920

2021
# pyre-ignore
@@ -44,19 +45,9 @@ def _is_mutable_buffer(
4445
if fqn in graph_signature.buffers_to_mutate.values():
4546
return True
4647
return False
47-
48-
49-
class SpecPropPass(ExportPass):
50-
def __init__(self) -> None:
51-
super().__init__()
52-
53-
def on_attr(self, attr: ProxyValue) -> None:
54-
attr.node.meta["spec"] = pytree.tree_map_only(
55-
torch.Tensor,
56-
make_spec,
57-
attr.data,
58-
)
59-
48+
class SpecPropPass:
49+
def __call__(self, gm: torch.fx.GraphModule) -> PassResult:
50+
return spec_prop_pass(gm)
6051
def update_placeholder_tensor_specs(
6152
self,
6253
exported_program: torch.export.ExportedProgram,
@@ -83,71 +74,16 @@ def update_placeholder_tensor_specs(
8374
):
8475
spec.const = True
8576

86-
# pyre-ignore
87-
def placeholder(self, name: str, arg, meta):
88-
meta["spec"] = make_spec(arg)
89-
return super().placeholder(name, arg, meta)
90-
91-
# pyre-ignore
92-
def call_operator(self, op, args, kwargs, meta):
93-
args_data, kwargs_data = pytree.tree_map_only(
94-
ProxyValue, lambda x: x.data, (args, kwargs)
95-
)
96-
meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data))
97-
return super().call_operator(op, args, kwargs, meta)
98-
99-
# pyre-ignore
100-
def call_getitem(self, value, key: int, meta):
101-
meta["spec"] = value.node.meta["spec"][key]
102-
return super().call_getitem(value, key, meta)
103-
104-
# pyre-ignore
105-
def call_cond(self, pred, true_fn, false_fn, inputs, meta):
106-
# true_fn/false_fn return tensors of the same shape, so we can pick
107-
# either one here.
108-
*_, true_out_node = true_fn.graph.nodes
109-
meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
110-
return super().call_cond(pred, true_fn, false_fn, inputs, meta)
111-
112-
def call_map(
113-
self,
114-
f: torch.fx.GraphModule,
115-
mapped_args: List[ProxyValue],
116-
operands: List[ProxyValue],
117-
meta: NodeMetadata,
118-
) -> ProxyValue:
119-
mapped_dim_size = [arg.data for arg in mapped_args][0].size(0)
120-
*_, body_out_node = f.graph.nodes
121-
body_out_node_fake_tensor = body_out_node.meta["val"]
122-
map_fake_tensor = pytree.tree_map_only(
123-
torch.Tensor,
124-
lambda x: x.new_empty(mapped_dim_size, *x.shape),
125-
body_out_node_fake_tensor,
126-
)
127-
meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor)
128-
return super().call_map(f, mapped_args, operands, meta)
129-
130-
# pyre-ignore
131-
def call_delegate(self, lowered_module, args, kwargs, meta):
132-
args_data, kwargs_data = pytree.tree_map_only(
133-
ProxyValue, lambda x: x.data, (args, kwargs)
134-
)
135-
# If spec is missing, re-genenrate it with args data
136-
if "spec" not in meta:
137-
meta["spec"] = pytree.tree_map(
138-
make_spec,
139-
executorch_call_delegate(lowered_module, *args_data),
140-
)
141-
return super().call_delegate(lowered_module, args, kwargs, meta)
142-
143-
# pyre-ignore
144-
def output(self, results, meta):
145-
# pyre-ignore
146-
def get_spec(x):
147-
if isinstance(x, ProxyValue):
148-
return x.node.meta["spec"]
149-
else:
150-
return make_spec(x)
151-
152-
meta["spec"] = pytree.tree_map(get_spec, results)
153-
return super().output(results, meta)
77+
def spec_prop_pass(gm: torch.fx.GraphModule) -> PassResult:
78+
# Update all the meta["val"]
79+
pass_result = ExportPass()(gm)
80+
assert pass_result is not None
81+
gm = pass_result.graph_module
82+
# set node.meta["spec"] based on meta["val"]
83+
for module in gm.modules():
84+
if isinstance(module, torch.fx.GraphModule):
85+
for node in module.graph.nodes:
86+
if node.op == "get_attr":
87+
continue
88+
node.meta["spec"] = pytree.tree_map(lambda meta_val: make_spec(meta_val), node.meta["val"])
89+
return pass_result

0 commit comments

Comments
 (0)