 import inspect

 import numpy as np
+from scipy.special import logsumexp as scipy_logsumexp
 from xarray import DataArray

 import pytensor.scalar as ps
 import pytensor.xtensor.math as pxm
 from pytensor import function
 from pytensor.scalar import ScalarOp
 from pytensor.xtensor.basic import rename
-from pytensor.xtensor.math import add, exp
+from pytensor.xtensor.math import add, exp, logsumexp
 from pytensor.xtensor.type import xtensor
 from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function

@@ -152,6 +153,28 @@ def test_cast():
     yc64.astype("float64")


+@pytest.mark.parametrize(
+    ["shape", "dims", "axis"],
+    [
+        ((3, 4), ("a", "b"), None),
+        ((3, 4), "a", 0),
+        ((3, 4), "b", 1),
+    ],
+)
+def test_logsumexp(shape, dims, axis):
+    scipy_inp = np.zeros(shape)
+    scipy_out = scipy_logsumexp(scipy_inp, axis=axis)
+
+    pytensor_inp = DataArray(scipy_inp, dims=("a", "b"))
+    f = function([], logsumexp(pytensor_inp, dim=dims))
+    pytensor_out = f()
+
+    np.testing.assert_array_almost_equal(
+        pytensor_out,
+        scipy_out,
+    )
+
+
 def test_dot():
     """Test basic dot product operations."""
     # Test matrix-vector dot product (with multiple-letter dim names)
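Side note (not part of the commit): the new test evaluates `logsumexp` on a constant `DataArray` of zeros. A minimal sketch of the same comparison against `scipy.special.logsumexp`, but with a symbolic `xtensor` input and the `xr_function`/`xr_arange_like` helpers this module already imports, might look like the code below. The `logsumexp(x, dim=...)` call mirrors the diff above; the test name, shapes, and helper usage are assumptions, and the snippet relies on the module-level imports shown in the diff.

# Sketch (assumption): same comparison with a symbolic xtensor input and
# non-trivial data, reusing the helpers already imported in this test module.
def test_logsumexp_symbolic():
    x = xtensor("x", dims=("a", "b"), shape=(3, 4))  # symbolic 3x4 input
    out = logsumexp(x, dim="a")                      # reduce over the "a" dimension

    fn = xr_function([x], out)                       # DataArray-in / DataArray-out wrapper
    x_test = xr_arange_like(x)                       # DataArray filled with arange values

    expected = DataArray(
        scipy_logsumexp(x_test.values, axis=0),      # scipy reference along the same axis
        dims=("b",),
    )
    xr_assert_allclose(fn(x_test), expected)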