-
Notifications
You must be signed in to change notification settings - Fork 145
Fix inner graph inplace rewrites in Numba / PyTorch backends #1247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix inner graph inplace rewrites in Numba / PyTorch backends #1247
Conversation
d49b1b0
to
38f544f
Compare
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1247 +/- ##
=======================================
Coverage 81.98% 81.99%
=======================================
Files 188 188
Lines 48542 48551 +9
Branches 8675 8673 -2
=======================================
+ Hits 39799 39810 +11
+ Misses 6581 6579 -2
Partials 2162 2162
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit confused by JAX is omitted here, when we specifically add such a supervisor in pymc during jax_funcify
Does this PR effectively close aesara-devs/aesara#637 ?
fgraph = op.fgraph | ||
add_supervisor_to_fgraph( | ||
fgraph=fgraph, | ||
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this scan is an inner function to something else, do we still want mutable = False
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now yes, this should be handled by inplace rewrites so that we know what inputs are safe to destroy by the time we get here. But I will still have to look at what those do exactly, and we weren't handling it before in the dispatch.
Scan has a very specific view of inplacing built around the constraint that they were compiling a full Pytensor function internally.
I'll open an issue to investigate
Also does |
That was added before we excluded inplace rewrites from JAX, it hasn't been needed for a while. Same with the Assert dispatch that's there |
Both do and both were handled in the backends that allow inplacing |
Only partially. It makes it easier with the refactor but doesn't do anything automatically. Perhaps we could add it by default and set all inputs to non-mutable but
Edit: I guess we can update the message to use the new helper |
JAX needs no special handling because it excludes inplace rewrites.
38f544f
to
561117e
Compare
JAX needs no special handling because it excludes inplace rewrites.
📚 Documentation preview 📚: https://pytensor--1247.org.readthedocs.build/en/1247/