Skip to content

Commit 67cce84

Browse files
Add regression test
1 parent 17c62f4 commit 67cce84

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytensor
1111
import pytensor.tensor as pt
1212
from pytensor import config
13+
from pytensor.tensor.slinalg import SolveTriangular
1314
from tests import unittest_tools as utt
1415
from tests.link.numba.test_basic import compare_numba_and_py
1516

@@ -130,6 +131,48 @@ def A_func_pt(x):
130131
)
131132

132133

134+
@pytest.mark.parametrize("overwrite_b", [True, False], ids=["inplace", "not_inplace"])
135+
def test_solve_triangular_overwrite_b_correct(overwrite_b):
136+
# Regression test for issue #1233
137+
138+
rng = np.random.default_rng(utt.fetch_seed())
139+
a_test_py = np.asfortranarray(rng.normal(size=(3, 3)))
140+
a_test_py = np.tril(a_test_py)
141+
b_test_py = np.asfortranarray(rng.normal(size=(3, 2)))
142+
143+
# .T.copy().T creates an f-contiguous copy of an f-contiguous array (otherwise the copy is c-contiguous)
144+
a_test_nb = a_test_py.T.copy().T
145+
b_test_nb = b_test_py.T.copy().T
146+
147+
op = SolveTriangular(
148+
trans=0,
149+
unit_diagonal=False,
150+
lower=False,
151+
check_finite=True,
152+
b_ndim=2,
153+
overwrite_b=overwrite_b,
154+
)
155+
156+
a_pt = pt.matrix("a", shape=(3, 3))
157+
b_pt = pt.matrix("b", shape=(3, 2))
158+
out = op(a_pt, b_pt)
159+
160+
py_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True)
161+
numba_fn = pytensor.function([a_pt, b_pt], out, accept_inplace=True, mode="NUMBA")
162+
163+
x_py = py_fn(a_test_py, b_test_py)
164+
x_nb = numba_fn(a_test_nb, b_test_nb)
165+
166+
np.testing.assert_allclose(
167+
py_fn(a_test_py, b_test_py), numba_fn(a_test_nb, b_test_nb)
168+
)
169+
np.testing.assert_allclose(b_test_py, b_test_nb)
170+
171+
if overwrite_b:
172+
np.testing.assert_allclose(b_test_py, x_py)
173+
np.testing.assert_allclose(b_test_nb, x_nb)
174+
175+
133176
@pytest.mark.parametrize("value", [np.nan, np.inf])
134177
@pytest.mark.filterwarnings(
135178
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'

0 commit comments

Comments
 (0)