|
15 | 15 | from torch.export.exported_program import ExportGraphSignature |
16 | 16 | from torch.fx.node import Node |
17 | 17 | from torch.utils import _pytree as pytree |
| 18 | +from torch.fx.passes.infra.pass_base import PassResult |
18 | 19 |
|
19 | 20 |
|
20 | 21 | # pyre-ignore |
@@ -45,109 +46,41 @@ def _is_mutable_buffer( |
45 | 46 | return True |
46 | 47 | return False |
47 | 48 |
|
48 | | - |
49 | | -class SpecPropPass(ExportPass): |
50 | | - def __init__(self) -> None: |
51 | | - super().__init__() |
52 | | - |
53 | | - def on_attr(self, attr: ProxyValue) -> None: |
54 | | - attr.node.meta["spec"] = pytree.tree_map_only( |
55 | | - torch.Tensor, |
56 | | - make_spec, |
57 | | - attr.data, |
58 | | - ) |
59 | | - |
60 | | - def update_placeholder_tensor_specs( |
61 | | - self, |
62 | | - exported_program: torch.export.ExportedProgram, |
63 | | - graph_module: torch.fx.GraphModule, |
64 | | - ) -> None: |
65 | | - """ |
66 | | - Update the tensor specs for all placeholder nodes such that |
67 | | - placeholders that are parameters are marked as constant. |
68 | | - """ |
69 | | - for node in graph_module.graph.nodes: |
70 | | - if node.op != "placeholder": |
71 | | - continue |
72 | | - if "spec" not in node.meta: |
73 | | - raise RuntimeError(f"Placeholder node {node} missing meta['spec']") |
74 | | - spec = node.meta["spec"] |
75 | | - if isinstance(node.target, str) and ( |
76 | | - node.target in exported_program.graph_signature.inputs_to_parameters |
77 | | - or ( |
78 | | - node.target in exported_program.graph_signature.inputs_to_buffers |
79 | | - and not _is_mutable_buffer(node, exported_program.graph_signature) |
80 | | - ) |
81 | | - or node.target |
82 | | - in exported_program.graph_signature.inputs_to_lifted_tensor_constants |
83 | | - ): |
84 | | - spec.const = True |
85 | | - |
86 | | - # pyre-ignore |
87 | | - def placeholder(self, name: str, arg, meta): |
88 | | - meta["spec"] = make_spec(arg) |
89 | | - return super().placeholder(name, arg, meta) |
90 | | - |
91 | | - # pyre-ignore |
92 | | - def call_operator(self, op, args, kwargs, meta): |
93 | | - args_data, kwargs_data = pytree.tree_map_only( |
94 | | - ProxyValue, lambda x: x.data, (args, kwargs) |
95 | | - ) |
96 | | - meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data)) |
97 | | - return super().call_operator(op, args, kwargs, meta) |
98 | | - |
99 | | - # pyre-ignore |
100 | | - def call_getitem(self, value, key: int, meta): |
101 | | - meta["spec"] = value.node.meta["spec"][key] |
102 | | - return super().call_getitem(value, key, meta) |
103 | | - |
104 | | - # pyre-ignore |
105 | | - def call_cond(self, pred, true_fn, false_fn, inputs, meta): |
106 | | - # true_fn/false_fn return tensors of the same shape, so we can pick |
107 | | - # either one here. |
108 | | - *_, true_out_node = true_fn.graph.nodes |
109 | | - meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"]) |
110 | | - return super().call_cond(pred, true_fn, false_fn, inputs, meta) |
111 | | - |
112 | | - def call_map( |
113 | | - self, |
114 | | - f: torch.fx.GraphModule, |
115 | | - mapped_args: List[ProxyValue], |
116 | | - operands: List[ProxyValue], |
117 | | - meta: NodeMetadata, |
118 | | - ) -> ProxyValue: |
119 | | - mapped_dim_size = [arg.data for arg in mapped_args][0].size(0) |
120 | | - *_, body_out_node = f.graph.nodes |
121 | | - body_out_node_fake_tensor = body_out_node.meta["val"] |
122 | | - map_fake_tensor = pytree.tree_map_only( |
123 | | - torch.Tensor, |
124 | | - lambda x: x.new_empty(mapped_dim_size, *x.shape), |
125 | | - body_out_node_fake_tensor, |
126 | | - ) |
127 | | - meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) |
128 | | - return super().call_map(f, mapped_args, operands, meta) |
129 | | - |
130 | | - # pyre-ignore |
131 | | - def call_delegate(self, lowered_module, args, kwargs, meta): |
132 | | - args_data, kwargs_data = pytree.tree_map_only( |
133 | | - ProxyValue, lambda x: x.data, (args, kwargs) |
134 | | - ) |
135 | | - # If spec is missing, re-generate it with args data
136 | | - if "spec" not in meta: |
137 | | - meta["spec"] = pytree.tree_map( |
138 | | - make_spec, |
139 | | - executorch_call_delegate(lowered_module, *args_data), |
| 49 | +def SpecPropPass(gm: torch.fx.GraphModule) -> PassResult: |
| 50 | + # Update all the meta["val"] |
| 51 | + pass_result = ExportPass()(gm) |
| 52 | + assert pass_result is not None |
| 53 | + gm = pass_result.graph_module |
| 54 | + # set node.meta["spec"] based on meta["val"] |
| 55 | + for module in gm.modules(): |
| 56 | + if isinstance(module, torch.fx.GraphModule): |
| 57 | + for node in module.graph.nodes: |
| 58 | + if node.op == "get_attr": |
| 59 | + continue |
| 60 | + node.meta["spec"] = pytree.tree_map(lambda meta_val: make_spec(meta_val), node.meta["val"]) |
| 61 | + return pass_result |
| 62 | + |
| 63 | +def update_placeholder_tensor_specs( |
| 64 | + exported_program: torch.export.ExportedProgram, |
| 65 | + graph_module: torch.fx.GraphModule, |
| 66 | +) -> None: |
| 67 | + """ |
| 68 | + Update the tensor specs for all placeholder nodes such that |
| 69 | + placeholders that are parameters are marked as constant. |
| 70 | + """ |
| 71 | + for node in graph_module.graph.nodes: |
| 72 | + if node.op != "placeholder": |
| 73 | + continue |
| 74 | + if "spec" not in node.meta: |
| 75 | + raise RuntimeError(f"Placeholder node {node} missing meta['spec']") |
| 76 | + spec = node.meta["spec"] |
| 77 | + if isinstance(node.target, str) and ( |
| 78 | + node.target in exported_program.graph_signature.inputs_to_parameters |
| 79 | + or ( |
| 80 | + node.target in exported_program.graph_signature.inputs_to_buffers |
| 81 | + and not _is_mutable_buffer(node, exported_program.graph_signature) |
140 | 82 | ) |
141 | | - return super().call_delegate(lowered_module, args, kwargs, meta) |
142 | | - |
143 | | - # pyre-ignore |
144 | | - def output(self, results, meta): |
145 | | - # pyre-ignore |
146 | | - def get_spec(x): |
147 | | - if isinstance(x, ProxyValue): |
148 | | - return x.node.meta["spec"] |
149 | | - else: |
150 | | - return make_spec(x) |
151 | | - |
152 | | - meta["spec"] = pytree.tree_map(get_spec, results) |
153 | | - return super().output(results, meta) |
| 83 | + or node.target |
| 84 | + in exported_program.graph_signature.inputs_to_lifted_tensor_constants |
| 85 | + ): |
| 86 | + spec.const = True |