|
30 | 30 | from pytensor.tensor.math import max as pt_max |
31 | 31 | from pytensor.tensor.math import min as pt_min |
32 | 32 | from pytensor.tensor.math import sum as pt_sum |
33 | | -from pytensor.tensor.special import SoftmaxGrad, softmax |
| 33 | +from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax |
34 | 34 | from pytensor.tensor.type import matrix, vector, vectors |
35 | 35 | from tests.link.mlx.test_basic import compare_mlx_and_py |
36 | 36 |
|
@@ -97,6 +97,15 @@ def test_softmax_grad(axis): |
97 | 97 | compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value]) |
98 | 98 |
|
99 | 99 |
|
| 100 | +@pytest.mark.parametrize("axis", [None, 0, 1]) |
| 101 | +def test_logsoftmax(axis): |
| 102 | + x = matrix("x") |
| 103 | + x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) |
| 104 | + out = log_softmax(x, axis=axis) |
| 105 | + |
| 106 | + compare_mlx_and_py([x], [out], [x_test_value]) |
| 107 | + |
| 108 | + |
100 | 109 | @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)]) |
101 | 110 | @pytest.mark.parametrize("axis", [0, 1]) |
102 | 111 | def test_logsumexp_benchmark(size, axis, benchmark): |
|
0 commit comments