Skip to content

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 25, 2025

JAX needs no special handling because it excludes inplace rewrites.


📚 Documentation preview 📚: https://pytensor--1247.org.readthedocs.build/en/1247/

@ricardoV94 ricardoV94 force-pushed the inner_graph_inplace_rewrites branch 2 times, most recently from d49b1b0 to 38f544f Compare February 25, 2025 18:48
Copy link

codecov bot commented Feb 25, 2025

Codecov Report

Attention: Patch coverage is 92.10526% with 3 lines in your changes missing coverage. Please review.

Project coverage is 81.99%. Comparing base (f12bea6) to head (561117e).
Report is 146 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/compile/function/types.py 76.92% 1 Missing and 2 partials ⚠️
Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1247   +/-   ##
=======================================
  Coverage   81.98%   81.99%           
=======================================
  Files         188      188           
  Lines       48542    48551    +9     
  Branches     8675     8673    -2     
=======================================
+ Hits        39799    39810   +11     
+ Misses       6581     6579    -2     
  Partials     2162     2162           
Files with missing lines Coverage Δ
pytensor/compile/mode.py 84.72% <ø> (ø)
pytensor/link/jax/dispatch/scan.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/basic.py 79.08% <100.00%> (+0.16%) ⬆️
pytensor/link/numba/dispatch/scan.py 95.97% <100.00%> (+0.11%) ⬆️
pytensor/link/numba/dispatch/subtensor.py 95.34% <100.00%> (ø)
pytensor/link/pytorch/dispatch/basic.py 87.87% <100.00%> (+0.47%) ⬆️
pytensor/scan/op.py 84.73% <100.00%> (+0.11%) ⬆️
pytensor/sparse/rewriting.py 76.15% <ø> (ø)
pytensor/compile/function/types.py 80.68% <76.92%> (+0.02%) ⬆️

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit confused by JAX is omitted here, when we specifically add such a supervisor in pymc during jax_funcify

Does this PR effectively close aesara-devs/aesara#637 ?

fgraph = op.fgraph
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this scan is an inner function to something else, do we still want mutable = False ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now yes, this should be handled by inplace rewrites so that we know what inputs are safe to destroy by the time we get here. But I will still have to look at what those do exactly, and we weren't handling it before in the dispatch.

Scan has a very specific view of inplacing built around the constraint that they were compiling a full Pytensor function internally.

I'll open an issue to investigate

@jessegrabowski
Copy link
Member

Also does OpFromGraph require any special consideration here? Or only scan?

@ricardoV94
Copy link
Member Author

A bit confused by JAX is omitted here, when we specifically add such a supervisor in pymc during jax_funcify

That was added before we excluded inplace rewrites from JAX, it hasn't been needed for a while. Same with the Assert dispatch that's there

@ricardoV94
Copy link
Member Author

Also does OpFromGraph require any special consideration here? Or only scan?

Both do and both were handled in the backends that allow inplacing

@ricardoV94
Copy link
Member Author

ricardoV94 commented Feb 26, 2025

Does this PR effectively close aesara-devs/aesara#637 ?

Only partially. It makes it easier with the refactor but doesn't do anything automatically. Perhaps we could add it by default and set all inputs to non-mutable but

  1. that may slow down user guided rewrites that shouldn't add inplace stuff anyway and
  2. I'm not sure the Supervisor/Destroy handler can handle (pun pun) a change in the protected inputs, so it may have needed more work to use the default of all non mutable. For sure more testing.

Edit: I guess we can update the message to use the new helper

@ricardoV94 ricardoV94 force-pushed the inner_graph_inplace_rewrites branch from 38f544f to 561117e Compare February 27, 2025 11:08
@ricardoV94 ricardoV94 merged commit 69efc68 into pymc-devs:main Feb 27, 2025
72 of 73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants