-
Notifications
You must be signed in to change notification settings - Fork 149
Robust logsumexp #1227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Robust logsumexp #1227
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1227 +/- ##
=======================================
Coverage 82.01% 82.01%
=======================================
Files 188 188
Lines 48561 48562 +1
Branches 8679 8679
=======================================
+ Hits 39826 39827 +1
Misses 6572 6572
Partials 2163 2163
🚀 New features to boost your workflow:
|
|
We do it during rewrites: import pytensor
import pytensor.tensor as pt
x = pt.vector("x")
y = pt.logsumexp(x)
with pytensor.config.change_flags(optimizer_verbose=True):
# rewriting: rewrite local_log_sum_exp replaces Log.0 of Log(Sum{axes=None}.0) with Add.0 of Add(Max{axes=None}.0, Log.0)
pytensor.graph.rewrite_graph(y, include=("stabilize",)).dprint()
You can add a comment instead if it's suprising looking at the source code |
ricardoV94
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Already done in rewrites
|
|
||
| return log(sum(exp(x), axis=axis, keepdims=keepdims)) | ||
| result = log( | ||
| sum(exp(x - max(x, axis=axis, keepdims=True)), axis=axis, keepdims=keepdims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's also a bit more complicated than that to handle cases of all infinities. The graph is pretty messy, that's why we do it later only
|
TIL |
Description
Updated logsumexp to protect against overflow.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1227.org.readthedocs.build/en/1227/