|
10 | 10 | import pytensor
|
11 | 11 | import pytensor.tensor as pt
|
12 | 12 | from pytensor import config
|
| 13 | +from pytensor.tensor.slinalg import SolveTriangular |
13 | 14 | from tests import unittest_tools as utt
|
14 | 15 | from tests.link.numba.test_basic import compare_numba_and_py
|
15 | 16 |
|
@@ -130,6 +131,48 @@ def A_func_pt(x):
|
130 | 131 | )
|
131 | 132 |
|
132 | 133 |
|
| 134 | +@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"]) |
| 135 | +def test_solve_triangular_overwrite_b_correct(overwrite_b): |
| 136 | + # Regression test for issue #1233 |
| 137 | + |
| 138 | + rng = np.random.default_rng(utt.fetch_seed()) |
| 139 | + a_test_py = np.asfortranarray(rng.normal(size=(3, 3))) |
| 140 | + a_test_py = np.tril(a_test_py) |
| 141 | + b_test_py = np.asfortranarray(rng.normal(size=(3, 2))) |
| 142 | + |
| 143 | + # .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous) |
| 144 | + a_test_nb = a_test_py.copy(order="F") |
| 145 | + b_test_nb = b_test_py.copy(order="F") |
| 146 | + |
| 147 | + op = SolveTriangular( |
| 148 | + trans=0, |
| 149 | + unit_diagonal=False, |
| 150 | + lower=False, |
| 151 | + check_finite=True, |
| 152 | + b_ndim=2, |
| 153 | + overwrite_b=overwrite_b, |
| 154 | + ) |
| 155 | + |
| 156 | + a_pt = pt.matrix("a", shape=(3, 3)) |
| 157 | + b_pt = pt.matrix("b", shape=(3, 2)) |
| 158 | + out = op(a_pt, b_pt) |
| 159 | + |
| 160 | + py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True) |
| 161 | + numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA") |
| 162 | + |
| 163 | + x_py = py_fn(a_test_py, b_test_py) |
| 164 | + x_nb = numba_fn(a_test_nb, b_test_nb) |
| 165 | + |
| 166 | + np.testing.assert_allclose( |
| 167 | + py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb) |
| 168 | + ) |
| 169 | + np.testing.assert_allclose(b_test_py, b_test_nb) |
| 170 | + |
| 171 | + if overwrite_b: |
| 172 | + np.testing.assert_allclose(b_test_py, x_py) |
| 173 | + np.testing.assert_allclose(b_test_nb, x_nb) |
| 174 | + |
| 175 | + |
133 | 176 | @pytest.mark.parametrize("value", [np.nan, np.inf])
|
134 | 177 | @pytest.mark.filterwarnings(
|
135 | 178 | 'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
|
|
0 commit comments