Skip to content

Conversation

fonnesbeck
Copy link
Member

@fonnesbeck fonnesbeck commented Feb 20, 2025

Description

Updated logsumexp to protect against overflow.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1227.org.readthedocs.build/en/1227/

Copy link

codecov bot commented Feb 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.01%. Comparing base (3cdcfde) to head (150f908).
Report is 148 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1227   +/-   ##
=======================================
  Coverage   82.01%   82.01%           
=======================================
  Files         188      188           
  Lines       48561    48562    +1     
  Branches     8679     8679           
=======================================
+ Hits        39826    39827    +1     
  Misses       6572     6572           
  Partials     2163     2163           
Files with missing lines Coverage Δ
pytensor/tensor/math.py 92.02% <100.00%> (+<0.01%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94
Copy link
Member

We do it during rewrites:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.logsumexp(x)

with pytensor.config.change_flags(optimizer_verbose=True):
    # rewriting: rewrite local_log_sum_exp replaces Log.0 of Log(Sum{axes=None}.0) with Add.0 of Add(Max{axes=None}.0, Log.0)
    pytensor.graph.rewrite_graph(y, include=("stabilize",)).dprint()
    

You can add a comment instead if it's suprising looking at the source code

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already done in rewrites


return log(sum(exp(x), axis=axis, keepdims=keepdims))
result = log(
sum(exp(x - max(x, axis=axis, keepdims=True)), axis=axis, keepdims=keepdims)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's also a bit more complicated than that to handle cases of all infinities. The graph is pretty messy, that's why we do it later only

@fonnesbeck
Copy link
Member Author

TIL

@fonnesbeck fonnesbeck closed this Feb 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants