-
Notifications
You must be signed in to change notification settings - Fork 149
Constant fold branches of variadic add/mul #1422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Constant fold branches of variadic add/mul #1422
Conversation
lucianopaz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don’t know why the test fails but it looks like the fusion rewrite is applied only once. Maybe the equilibrium rewrite that you took out should be added back in?
|
@ricardoV94, I just went through your branch's code and found that the error is coming from the fact that the |
ca12b58 to
082e1b7
Compare
| include=[ | ||
| "canonicalize", | ||
| "fusion", | ||
| "add_mul_flat", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was the change needed to get the test to pass @lucianopaz
|
@lucianopaz I came to the same conclusion, I just added the rewrite explicitly. Mentioned in an inline comment above |
082e1b7 to
70db72e
Compare
lucianopaz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @ricardoV94 !
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1422 +/- ##
==========================================
+ Coverage 82.11% 82.13% +0.01%
==========================================
Files 211 211
Lines 49743 49773 +30
Branches 8824 8830 +6
==========================================
+ Hits 40847 40879 +32
+ Misses 6715 6714 -1
+ Partials 2181 2180 -1
🚀 New features to boost your workflow:
|
Refactoring and renaming:
local_add_mul_fusionfunction toflatten_nested_add_multo more precisely reflect how it works (one could also fuse non-nested add/mul, like the FusionOptimizer does). The function now explicitly tracksaddandmuloperations instead of relying on genericElemwisechecks. [1] [2] [3]New optimization for constant folding:
constant_fold_branches_of_add_mul, which folds constants in add/mul operations when it does not result in higher intermediate memory usage. This optimization is registered in a new sequence database,add_mul_flat_seqopt, which runs before generic elementwise fusion.The two rewrites are pulled out to a separate database so it's included in JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoding XLA constant fold (CC @lucianopaz)
📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/