Fix double-tracing in SpecPropPass #15485
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15485
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure) As of commit 7da6e25 with merge base 18c1c5b. BROKEN TRUNK: the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@GregoryComer has exported this pull request. If you are a Meta employee, you can view the originating Diff in D85913581.
Summary: Pull Request resolved: pytorch#15485 Differential Revision: D85913581
One additional note is that the aliasing analysis feels pretty fragile as-is. I fixed several subtle issues (luckily caught by CI) where my changes were accidentally planning two separate tensors when they should alias / share one TensorSpec. I'm wondering if we should rewrite this pass again to either rely on ProxyValue reference equality or otherwise introduce some proper aliasing analysis, as opposed to hard-coding that getitem and output, for example, always alias their argument. The current approach seems like it could get messy with non-functional custom ops or defunctionalization in general. @JacobSzwejbka @angelayi what are your thoughts on this?
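To make the reference-equality idea concrete, here is a minimal sketch (illustrative only; `make_spec` comes from the pass, and the cache and helper are assumptions, not code from this PR):

```python
from typing import Any, Dict

# Hedged sketch: share one TensorSpec per unique traced value identity,
# instead of hard-coding that getitem/output alias their argument.
_spec_cache: Dict[int, Any] = {}

def spec_for(meta_val):
    key = id(meta_val)  # reference equality on the traced value
    if key not in _spec_cache:
        _spec_cache[key] = make_spec(meta_val)  # make_spec is from the pass
    return _spec_cache[key]
```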
@GregoryComer has imported this pull request. If you are a Meta employee, you can view this in D85913581.
Note that the moshi and zephyr size test failures are pre-existing.
Pull request overview
This PR fixes a double-tracing issue in SpecPropPass where tensor specs were generated with symints from a different trace than the one used for guards, causing guards on unbacked symints to be lost. The fix refactors SpecPropPass to perform a single re-trace using the parent ExportPass class and then generate specs from the resulting metadata, ensuring consistency between specs and guards.
Key changes:
- Rewrote `SpecPropPass.__call__()` to re-trace once and populate specs from meta values (see the sketch below)
- Removed individual node handler methods (placeholder, call_operator, call_getitem, etc.) in favor of unified spec generation
- Added a test case with a custom op using unbacked symints to verify guard propagation
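A minimal sketch of that single re-trace strategy, assuming `ExportPass` and `make_spec` from exir (the exact PR code may differ):

```python
import torch.utils._pytree as pytree

class SpecPropPass(ExportPass):  # ExportPass from exir.pass_base (assumed import)
    def __call__(self, graph_module):
        # Re-trace exactly once: the parent ExportPass rebuilds the graph
        # and refreshes node.meta["val"] under a single shape environment.
        result = super().__call__(graph_module)
        for node in result.graph_module.graph.nodes:
            if "val" not in node.meta:
                continue
            # Specs are derived from the same meta values the guards were
            # traced against, so unbacked symints stay consistent.
            node.meta["spec"] = pytree.tree_map(make_spec, node.meta["val"])
        return result
```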
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| exir/passes/spec_prop_pass.py | Complete rewrite of SpecPropPass to use single re-trace strategy; replaces per-node callbacks with post-trace spec generation from meta values |
| exir/tests/test_passes.py | Adds custom ops (unbacked, unbacked.out) and test case to verify spec propagation correctly captures guards for unbacked symints |
| if "spec" not in node.meta: | ||
| node.meta["spec"] = pytree.tree_map(make_spec, meta_val) | ||
| else: | ||
| node.meta["spec"] = pytree.tree_map(make_spec, meta_val) |
Consider merging these 2 conditions?
There is some weird existing behavior here that seems to need to be preserved (barring a larger update). Basically, we don't want to regenerate call_delegate node specs but do want to regenerate everything else. I'll add a comment detailing why.
```python
def call_cond(self, pred, true_fn, false_fn, inputs, meta):
    # true_fn/false_fn return tensors of the same shape, so we can pick
    # either one here.
    *_, true_out_node = true_fn.graph.nodes
    meta["spec"] = pytree.tree_map(make_spec, true_out_node.meta["val"])
    return super().call_cond(pred, true_fn, false_fn, inputs, meta)

def call_while(
    self,
    cond_fn: torch.fx.GraphModule,
    body_fn: torch.fx.GraphModule,
    carried_inputs: List[ProxyValue],
    additional_inputs: List[ProxyValue],
    meta: NodeMetadata,
):
    meta["spec"] = pytree.tree_map(make_spec, carried_inputs)
    return super().call_while(
        cond_fn, body_fn, carried_inputs, additional_inputs, meta
    )
```
Why don't we have to handle cond and while anymore?
They should be handled by the tracing logic: ExportPass regenerates the meta values, and we then assign spec values for each node correspondingly. I did go ahead and add specific tests for cond and while to verify that the specs are generated correctly. As long as the cond and while outputs don't alias anything else (my understanding is that this should be the case), it should be fine.
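In other words, under the new scheme cond/while nodes need no special casing; a hedged sketch of the generic post-trace assignment (names assumed from the pass):

```python
# After the single re-trace, higher-order ops like cond/while carry fresh
# meta["val"] entries just like any other node, so the generic assignment
# below covers them without dedicated call_cond/call_while handlers.
for node in graph_module.graph.nodes:
    if "val" in node.meta:
        node.meta["spec"] = pytree.tree_map(make_spec, node.meta["val"])
```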
```python
elif (
    node.op == "call_function"
    and node.target == executorch_call_delegate
):
    # Note: We currently rely on delegate node specs not being regenerated,
    # as the spec is set somewhat manually when adding the call delegate node.
    # If we regenerate, it can change and break lowering (it becomes a tuple?).
    # Ideally, we should figure out how to make the spec regeneration not break
    # things.
    #
    # We do need to regenerate non-call-delegate node specs, as this pass is called
    # multiple times in some lowering paths (backends can and do call it).
    if "spec" not in node.meta:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
    else:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
```
I meant something like:
```python
elif (
    node.op == "call_function"
    and node.target == executorch_call_delegate
):
    # Note: We currently rely on delegate node specs not being regenerated,
    # as the spec is set somewhat manually when adding the call delegate node.
    # If we regenerate, it can change and break lowering (it becomes a tuple?).
    # Ideally, we should figure out how to make the spec regeneration not break
    # things.
    #
    # We do need to regenerate non-call-delegate node specs, as this pass is called
    # multiple times in some lowering paths (backends can and do call it).
    if "spec" not in node.meta:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
    else:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
else:
    if "spec" not in node.meta:
        node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
```
Sure - the issue that I've seen is that sometimes this pass gets called multiple times (the Cadence backend does this, for example), and thus we need to regenerate the spec for most nodes to make sure they pick up on any shape changes between calls.
But if we regenerate the spec for call_delegate nodes, it breaks things. So the `if "spec" not in node.meta:` condition should only apply to call_delegate and not to anything else. Otherwise it breaks existing backend assumptions.
Ideally, we'll do a deeper change to fix this, but this preserves the existing behavior. I could change the line to `if "spec" not in node.meta or node.target != executorch_call_delegate` if you'd prefer (see the sketch below).
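For reference, that condensed form would look something like this (a sketch against the diff above, preserving the same behavior):

```python
# Regenerate the spec for every node except call_delegate nodes that
# already carry a manually-set spec (same behavior as the if/else above).
if "spec" not in node.meta or node.target != executorch_call_delegate:
    node.meta["spec"] = pytree.tree_map(make_spec, meta_val)
```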
Our current SpecPropPass doesn't properly capture the effect of guards in the shape environment due to double-tracing certain ops. The problem looks like this:

* Every time we trace through the graph, we generate new symints.
* That's fine, since shape_env will pick up guards during the retrace.
* The problem is that SpecPropPass does this twice: once to generate the spec, and then once by calling super().call_operator(...) (see https://github.com/pytorch/executorch/blob/11f752cf84b296a39c0b74b889d618f279bc8186/exir/passes/spec_prop_pass.py#L98).
* The tensor spec gets the symint from the first trace, but the graph and guards use the second.
* Hence the tensor spec doesn't pick up on guards.
To resolve this, I've updated the SpecPropPass to re-trace the graph and then generate specs based on the meta values, not the traced ProxyValues (thanks @angelayi for the suggestion). This resolves the issue.
I originally saw this issue with the NMS torchvision op, but to avoid adding a new dep to the core EXIR tests, I've written a test with a custom op that uses an unbacked symint in the meta kernel output shape to replicate the bug in the same way.
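For illustration, a custom op whose fake/meta kernel allocates an unbacked symint for a data-dependent output dim might look like the sketch below; the `demo::unbacked` name and schema are hypothetical, not the actual op added in the PR:

```python
import torch

# Hypothetical op namespace and schema, for illustration only.
lib = torch.library.Library("demo", "DEF")
lib.define("unbacked(Tensor x) -> Tensor")

@torch.library.impl(lib, "unbacked", "CompositeExplicitAutograd")
def unbacked_impl(x):
    # Data-dependent output size: keep only the positive elements.
    return x[x > 0]

@torch.library.register_fake("demo::unbacked")
def unbacked_fake(x):
    # Allocate a fresh unbacked symint for the data-dependent dim.
    # Guards recorded on it during export must survive into the tensor
    # specs, which is the behavior the single re-trace is meant to ensure.
    ctx = torch.library.get_ctx()
    n = ctx.new_dynamic_size()
    return x.new_empty(n)
```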
Differential Revision: D85913581