Skip to content

Conversation

@eby0303
Copy link
Contributor

@eby0303 eby0303 commented Oct 16, 2025

Description

This is a draft PR for issue #1593.
I’m setting up the local environment and exploring how to implement a rewrite that fuses nested BlockDiag Ops into a single one.
I’ll update this PR with code once the setup is complete and I have an initial version of the rewrite.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1671.org.readthedocs.build/en/1671/

@eby0303
Copy link
Contributor Author

eby0303 commented Oct 16, 2025

Hey @jessegrabowski ! As suggested in the issue discussion, I’ve opened this draft PR to start working on the BlockDiag rewrite.
I’ll be setting up locally. Please feel free to share any tips or guidance for where to begin
Will update this PR as I make progress.

@jessegrabowski
Copy link
Member

I'd suggest you work in a test-driven way. Add a test to tests/tensor/rewriting/test_linalg.py with a simple nested blockwise and count the number of BlockDiag ops, and assert that there is only 1. Confim that this test fails. Then add a rewrite to tensor/rewriting/linalg.py that looks for a blockwise with a blockwise inside, and if so merges them.

For an example of how to count ops in a graph for the test, look here (BUT the whole class is overkill for your case, just take the pieces from it and write an inline version).

For a good rewrite template to get you started, I think this one is pretty readable. You will need to 1) check that the input is a BlockDiag op, 2) check that at least one of the inputs to the BlockDiag is a BlockDiag, 3) pull out the inputs from the inner BlockDiag, 4) make a new BlockDiag with n_inputs = old_n_inputs + 1 and return it, passing in all 3 inputs.

@eby0303
Copy link
Contributor Author

eby0303 commented Oct 16, 2025

@jessegrabowski ! Added a rewrite to fuse nested BlockDiagonal ops and updated test_linalg.py with a test for nested BlockDiagonal fusion.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really great first pass. You're missing some of the boiler plate around rewrites, have a look here (or anywhere in this file really) to see how to register a rewrite, and how to tell it which Op to track (pay attention to the decorators).

You also need to use pytensor.function to compile your block diagonal graph and check that the rewrite was triggered, rather than calling it directly.

@eby0303
Copy link
Contributor Author

eby0303 commented Nov 4, 2025

@jessegrabowski Added a rewrite to fuse nested BlockDiagonal ops into a single fused instance and included tests to verify fusion behavior, n_inputs, and output shape.

@jessegrabowski jessegrabowski force-pushed the fuse-blockdiag-rewrite branch from bec4bd3 to 13b71f0 Compare January 8, 2026 03:48
@jessegrabowski jessegrabowski marked this pull request as ready for review January 8, 2026 03:48
@jessegrabowski jessegrabowski changed the title WIP: Add rewrite to fuse nested BlockDiag Ops Add rewrite to fuse nested BlockDiag Ops Jan 8, 2026
@ricardoV94 ricardoV94 merged commit 945e979 into pymc-devs:main Jan 8, 2026
66 checks passed
@ricardoV94
Copy link
Member

Oops I rebase-merged instead of squash. Anyway thanks @eby0303 and @jessegrabowski

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add rewrite to fuse nested BlockDiag Ops

3 participants