-
Notifications
You must be signed in to change notification settings - Fork 747
Fix double-tracing in SpecPropPass #15485
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,14 +6,16 @@ | |||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-strict | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| from typing import List, Optional | ||||||||||||||||||||||||||||||||||||||||
| import operator | ||||||||||||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||
| from executorch.exir.delegate import executorch_call_delegate | ||||||||||||||||||||||||||||||||||||||||
| from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue | ||||||||||||||||||||||||||||||||||||||||
| from executorch.exir.pass_base import ExportPass, ProxyValue | ||||||||||||||||||||||||||||||||||||||||
| from executorch.exir.tensor import TensorSpec | ||||||||||||||||||||||||||||||||||||||||
| from torch.export.exported_program import ExportGraphSignature | ||||||||||||||||||||||||||||||||||||||||
| from torch.fx.node import Node | ||||||||||||||||||||||||||||||||||||||||
| from torch.fx.passes.infra.pass_base import PassResult | ||||||||||||||||||||||||||||||||||||||||
| from torch.utils import _pytree as pytree | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
|
|
@@ -52,12 +54,48 @@ class SpecPropPass(ExportPass): | |||||||||||||||||||||||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def on_attr(self, attr: ProxyValue) -> None: | ||||||||||||||||||||||||||||||||||||||||
| attr.node.meta["spec"] = pytree.tree_map_only( | ||||||||||||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||
| make_spec, | ||||||||||||||||||||||||||||||||||||||||
| attr.data, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||||||||||||||||||||||||||||||||||||||||
| # Re-trace metadata to ensure it's up to date. | ||||||||||||||||||||||||||||||||||||||||
| res = ExportPass()(graph_module) | ||||||||||||||||||||||||||||||||||||||||
| assert res is not None | ||||||||||||||||||||||||||||||||||||||||
GregoryComer marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||
| gm = res.graph_module | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def get_spec(x): | ||||||||||||||||||||||||||||||||||||||||
| if hasattr(x, "meta"): | ||||||||||||||||||||||||||||||||||||||||
| return x.meta.get("spec", None) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| for module in gm.modules(): | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(module, torch.fx.GraphModule): | ||||||||||||||||||||||||||||||||||||||||
| for node in module.graph.nodes: | ||||||||||||||||||||||||||||||||||||||||
| meta_val = node.meta.get("val", None) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| if node.op == "output": | ||||||||||||||||||||||||||||||||||||||||
| node.meta["spec"] = pytree.tree_map(get_spec, node.args[0]) | ||||||||||||||||||||||||||||||||||||||||
| elif node.op == "call_function" and node.target == operator.getitem: | ||||||||||||||||||||||||||||||||||||||||
| value_spec = pytree.tree_map(get_spec, node.args[0]) | ||||||||||||||||||||||||||||||||||||||||
| node.meta["spec"] = value_spec[node.args[1]] | ||||||||||||||||||||||||||||||||||||||||
| elif ( | ||||||||||||||||||||||||||||||||||||||||
| node.op == "call_function" | ||||||||||||||||||||||||||||||||||||||||
| and node.target == executorch_call_delegate | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
| # Note: We currently rely on delegate node specs not being regenerated, | ||||||||||||||||||||||||||||||||||||||||
| # as the spec is set somewhat manually when adding the call delegate node. | ||||||||||||||||||||||||||||||||||||||||
| # If we regenerate, it can change and break lowering (it becomes a tuple?). | ||||||||||||||||||||||||||||||||||||||||
| # Ideally, we should figure out how to make the spec regeneration not break | ||||||||||||||||||||||||||||||||||||||||
| # things. | ||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||
| # We do need to regenerate non-call-delegate node specs, as this pass is called | ||||||||||||||||||||||||||||||||||||||||
| # multiple times in some lowering paths (backends can and do call it). | ||||||||||||||||||||||||||||||||||||||||
| if "spec" not in node.meta: | ||||||||||||||||||||||||||||||||||||||||
| node.meta["spec"] = pytree.tree_map(make_spec, meta_val) | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| node.meta["spec"] = pytree.tree_map(make_spec, meta_val) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+91
to
+94
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider merging these 2 conditions?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is some weird existing behavior here that seems to need to be preserved (barring a larger update). Basically, we don't want to regenerate call_delegate node specs but do want to regenerate everything else. I'll add a comment detailing why.
Comment on lines
+79
to
+94
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant something like:
Suggested change
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure - the issue that I've seen is that sometimes this pass gets called multiple times (Cadence backend does this, for example) and thus we need to regenerate the spec for most nodes to make sure they pick up on any shape changes between calls. But if we regenerate the spec for call_delegate nodes, it breaks things. So the Ideally, we'll do a deeper change to fix this but this preserves the existing behavior. I could change the line to |
||||||||||||||||||||||||||||||||||||||||
| return res | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | ||||||||||||||||||||||||||||||||||||||||
| return self(graph_module) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def update_placeholder_tensor_specs( | ||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||
|
|
@@ -84,85 +122,3 @@ def update_placeholder_tensor_specs( | |||||||||||||||||||||||||||||||||||||||
| in exported_program.graph_signature.inputs_to_lifted_tensor_constants | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
| spec.const = True | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def placeholder(self, name: str, arg, meta): | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = make_spec(arg) | ||||||||||||||||||||||||||||||||||||||||
| return super().placeholder(name, arg, meta) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def call_operator(self, op, args, kwargs, meta): | ||||||||||||||||||||||||||||||||||||||||
| args_data, kwargs_data = pytree.tree_map_only( | ||||||||||||||||||||||||||||||||||||||||
| ProxyValue, lambda x: x.data, (args, kwargs) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = pytree.tree_map(make_spec, op(*args_data, **kwargs_data)) | ||||||||||||||||||||||||||||||||||||||||
| return super().call_operator(op, args, kwargs, meta) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def call_getitem(self, value, key: int, meta): | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = value.node.meta["spec"][key] | ||||||||||||||||||||||||||||||||||||||||
| return super().call_getitem(value, key, meta) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def call_cond(self, pred, true_fn, false_fn, inputs, meta): | ||||||||||||||||||||||||||||||||||||||||
| # true_fn/false_fn return tensors of the same shape, so we can pick | ||||||||||||||||||||||||||||||||||||||||
| # either one here. | ||||||||||||||||||||||||||||||||||||||||
| *_, true_out_node = true_fn.graph.nodes | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"]) | ||||||||||||||||||||||||||||||||||||||||
| return super().call_cond(pred, true_fn, false_fn, inputs, meta) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def call_while( | ||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||
| cond_fn: torch.fx.GraphModule, | ||||||||||||||||||||||||||||||||||||||||
| body_fn: torch.fx.GraphModule, | ||||||||||||||||||||||||||||||||||||||||
| carried_inputs: List[ProxyValue], | ||||||||||||||||||||||||||||||||||||||||
| additional_inputs: List[ProxyValue], | ||||||||||||||||||||||||||||||||||||||||
| meta: NodeMetadata, | ||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = pytree.tree_map(make_spec, carried_inputs) | ||||||||||||||||||||||||||||||||||||||||
| return super().call_while( | ||||||||||||||||||||||||||||||||||||||||
| cond_fn, body_fn, carried_inputs, additional_inputs, meta | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
-107
to
-125
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we don't have to handle condition and while anymore?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They should be handled by having the tracing logic use ExportPass to regenerate the meta values and then assigning spec values for each node correspondingly. I did go ahead and specific tests for cond and while to verify that the specs are generated correctly. As long as the cond + while outputs don't alias anything else (my understanding is that this should be the case), it should be good. |
||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| def call_map( | ||||||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||||||
| f: torch.fx.GraphModule, | ||||||||||||||||||||||||||||||||||||||||
| mapped_args: List[ProxyValue], | ||||||||||||||||||||||||||||||||||||||||
| operands: List[ProxyValue], | ||||||||||||||||||||||||||||||||||||||||
| meta: NodeMetadata, | ||||||||||||||||||||||||||||||||||||||||
| ) -> ProxyValue: | ||||||||||||||||||||||||||||||||||||||||
| mapped_dim_size = [arg.data for arg in mapped_args][0].size(0) | ||||||||||||||||||||||||||||||||||||||||
| *_, body_out_node = f.graph.nodes | ||||||||||||||||||||||||||||||||||||||||
| body_out_node_fake_tensor = body_out_node.meta["val"] | ||||||||||||||||||||||||||||||||||||||||
| map_fake_tensor = pytree.tree_map_only( | ||||||||||||||||||||||||||||||||||||||||
| torch.Tensor, | ||||||||||||||||||||||||||||||||||||||||
| lambda x: x.new_empty(mapped_dim_size, *x.shape), | ||||||||||||||||||||||||||||||||||||||||
| body_out_node_fake_tensor, | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = pytree.tree_map(make_spec, map_fake_tensor) | ||||||||||||||||||||||||||||||||||||||||
| return super().call_map(f, mapped_args, operands, meta) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def call_delegate(self, lowered_module, args, kwargs, meta): | ||||||||||||||||||||||||||||||||||||||||
| args_data, kwargs_data = pytree.tree_map_only( | ||||||||||||||||||||||||||||||||||||||||
| ProxyValue, lambda x: x.data, (args, kwargs) | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| # If spec is missing, re-genenrate it with args data | ||||||||||||||||||||||||||||||||||||||||
| if "spec" not in meta: | ||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = pytree.tree_map( | ||||||||||||||||||||||||||||||||||||||||
| make_spec, | ||||||||||||||||||||||||||||||||||||||||
| executorch_call_delegate(lowered_module, *args_data), | ||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||
| return super().call_delegate(lowered_module, args, kwargs, meta) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def output(self, results, meta): | ||||||||||||||||||||||||||||||||||||||||
| # pyre-ignore | ||||||||||||||||||||||||||||||||||||||||
| def get_spec(x): | ||||||||||||||||||||||||||||||||||||||||
| if isinstance(x, ProxyValue): | ||||||||||||||||||||||||||||||||||||||||
| return x.node.meta["spec"] | ||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||
| return make_spec(x) | ||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||
| meta["spec"] = pytree.tree_map(get_spec, results) | ||||||||||||||||||||||||||||||||||||||||
| return super().output(results, meta) | ||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.