-
Notifications
You must be signed in to change notification settings - Fork 155
Add rewrite to fuse nested BlockDiag Ops #1671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Hey @jessegrabowski ! As suggested in the issue discussion, I’ve opened this draft PR to start working on the BlockDiag rewrite. |
|
I'd suggest you work in a test-driven way. Add a test to For an example of how to count ops in a graph for the test, look here (BUT the whole class is overkill for your case, just take the pieces from it and write an inline version). For a good rewrite template to get you started, I think this one is pretty readable. You will need to 1) check that the input is a BlockDiag op, 2) check that at least one of the inputs to the BlockDiag is a BlockDiag, 3) pull out the inputs from the inner BlockDiag, 4) make a new BlockDiag with |
|
@jessegrabowski ! Added a rewrite to fuse nested BlockDiagonal ops and updated test_linalg.py with a test for nested BlockDiagonal fusion. |
jessegrabowski
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really great first pass. You're missing some of the boiler plate around rewrites, have a look here (or anywhere in this file really) to see how to register a rewrite, and how to tell it which Op to track (pay attention to the decorators).
You also need to use pytensor.function to compile your block diagonal graph and check that the rewrite was triggered, rather than calling it directly.
|
@jessegrabowski Added a rewrite to fuse nested BlockDiagonal ops into a single fused instance and included tests to verify fusion behavior, n_inputs, and output shape. |
bec4bd3 to
13b71f0
Compare
|
Oops I rebase-merged instead of squash. Anyway thanks @eby0303 and @jessegrabowski |
Description
This is a draft PR for issue #1593.
I’m setting up the local environment and exploring how to implement a rewrite that fuses nested BlockDiag Ops into a single one.
I’ll update this PR with code once the setup is complete and I have an initial version of the rewrite.
Related Issue
BlockDiagOps #1593Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1671.org.readthedocs.build/en/1671/