Skip to content

Commit 1d825dd

Browse files
zaxtaxricardoV94
authored andcommitted
Add logsumexp to xtensor
1 parent 4312d8c commit 1d825dd

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pytensor/xtensor/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,11 @@ def softmax(x, dim=None):
512512
return exp_x / exp_x.sum(dim=dim)
513513

514514

515+
def logsumexp(x, dim=None):
516+
"""Compute the logsumexp of an XTensorVariable along a specified dimension."""
517+
return log(exp(x).sum(dim=dim))
518+
519+
515520
class Dot(XOp):
516521
"""Matrix multiplication between two XTensorVariables.
517522

tests/xtensor/test_math.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import inspect
88

99
import numpy as np
10+
from scipy.special import logsumexp as scipy_logsumexp
1011
from xarray import DataArray
1112

1213
import pytensor.scalar as ps
1314
import pytensor.xtensor.math as pxm
1415
from pytensor import function
1516
from pytensor.scalar import ScalarOp
1617
from pytensor.xtensor.basic import rename
17-
from pytensor.xtensor.math import add, exp
18+
from pytensor.xtensor.math import add, exp, logsumexp
1819
from pytensor.xtensor.type import xtensor
1920
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
2021

@@ -152,6 +153,28 @@ def test_cast():
152153
yc64.astype("float64")
153154

154155

156+
@pytest.mark.parametrize(
157+
["shape", "dims", "axis"],
158+
[
159+
((3, 4), ("a", "b"), None),
160+
((3, 4), "a", 0),
161+
((3, 4), "b", 1),
162+
],
163+
)
164+
def test_logsumexp(shape, dims, axis):
165+
scipy_inp = np.zeros(shape)
166+
scipy_out = scipy_logsumexp(scipy_inp, axis=axis)
167+
168+
pytensor_inp = DataArray(scipy_inp, dims=("a", "b"))
169+
f = function([], logsumexp(pytensor_inp, dim=dims))
170+
pytensor_out = f()
171+
172+
np.testing.assert_array_almost_equal(
173+
pytensor_out,
174+
scipy_out,
175+
)
176+
177+
155178
def test_dot():
156179
"""Test basic dot product operations."""
157180
# Test matrix-vector dot product (with multiple-letter dim names)

0 commit comments

Comments
 (0)