Skip to content

Commit cbae16b

Browse files
committed
Remove redundant numba lapack tests
Already covered in the Solve tests
1 parent f330a9f commit cbae16b

File tree

1 file changed

+0
-72
lines changed

1 file changed

+0
-72
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66
import pytest
7-
from numpy.testing import assert_allclose
87

98
import pytensor
109
import pytensor.tensor as pt
@@ -328,77 +327,6 @@ def gecon(x, norm):
328327
np.testing.assert_allclose(rcond, rcond2)
329328

330329

331-
@pytest.mark.parametrize("overwrite_a", [True, False])
332-
def test_getrf(overwrite_a):
333-
from scipy.linalg import lu_factor
334-
335-
from pytensor.link.numba.dispatch.slinalg import _getrf
336-
337-
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor
338-
339-
@numba.njit()
340-
def getrf(x, overwrite_a):
341-
return _getrf(x, overwrite_a=overwrite_a)
342-
343-
x = np.random.normal(size=(5, 5)).astype(floatX)
344-
x = np.asfortranarray(
345-
x
346-
) # x needs to be fortran-contiguous going into getrf for the overwrite option to work
347-
348-
lu, ipiv = lu_factor(x, overwrite_a=False)
349-
LU, IPIV, info = getrf(x, overwrite_a=overwrite_a)
350-
351-
assert info == 0
352-
assert_allclose(LU, lu)
353-
354-
if overwrite_a:
355-
assert_allclose(x, LU)
356-
357-
# TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing
358-
# this, though.
359-
assert_allclose(IPIV - 1, ipiv)
360-
361-
362-
@pytest.mark.parametrize("trans", [0, 1])
363-
@pytest.mark.parametrize("overwrite_a", [True, False])
364-
@pytest.mark.parametrize("overwrite_b", [True, False])
365-
@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"])
366-
def test_getrs(trans, overwrite_a, overwrite_b, b_shape):
367-
from scipy.linalg import lu_factor
368-
from scipy.linalg import lu_solve as sp_lu_solve
369-
370-
from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs
371-
372-
# TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor
373-
374-
@numba.njit()
375-
def lu_solve(a, b, trans, overwrite_a, overwrite_b):
376-
lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a)
377-
x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b)
378-
return x, lu, info
379-
380-
a = np.random.normal(size=(5, 5)).astype(floatX)
381-
b = np.random.normal(size=b_shape).astype(floatX)
382-
383-
# inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work
384-
a = np.asfortranarray(a)
385-
b = np.asfortranarray(b)
386-
387-
lu_and_piv = lu_factor(a, overwrite_a=False)
388-
x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False)
389-
390-
x, lu, info = lu_solve(
391-
a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b
392-
)
393-
assert info == 0
394-
if overwrite_a:
395-
assert_allclose(a, lu)
396-
if overwrite_b:
397-
assert_allclose(b, x)
398-
399-
assert_allclose(x, x_sp)
400-
401-
402330
@pytest.mark.filterwarnings(
403331
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
404332
)

0 commit comments

Comments
 (0)