|
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import pytest |
7 | | -from numpy.testing import assert_allclose |
8 | 7 |
|
9 | 8 | import pytensor |
10 | 9 | import pytensor.tensor as pt |
@@ -328,77 +327,6 @@ def gecon(x, norm): |
328 | 327 | np.testing.assert_allclose(rcond, rcond2) |
329 | 328 |
|
330 | 329 |
|
331 | | -@pytest.mark.parametrize("overwrite_a", [True, False]) |
332 | | -def test_getrf(overwrite_a): |
333 | | - from scipy.linalg import lu_factor |
334 | | - |
335 | | - from pytensor.link.numba.dispatch.slinalg import _getrf |
336 | | - |
337 | | - # TODO: Refactor this test to use compare_numba_and_py after we implement lu_factor in pytensor |
338 | | - |
339 | | - @numba.njit() |
340 | | - def getrf(x, overwrite_a): |
341 | | - return _getrf(x, overwrite_a=overwrite_a) |
342 | | - |
343 | | - x = np.random.normal(size=(5, 5)).astype(floatX) |
344 | | - x = np.asfortranarray( |
345 | | - x |
346 | | - ) # x needs to be fortran-contiguous going into getrf for the overwrite option to work |
347 | | - |
348 | | - lu, ipiv = lu_factor(x, overwrite_a=False) |
349 | | - LU, IPIV, info = getrf(x, overwrite_a=overwrite_a) |
350 | | - |
351 | | - assert info == 0 |
352 | | - assert_allclose(LU, lu) |
353 | | - |
354 | | - if overwrite_a: |
355 | | - assert_allclose(x, LU) |
356 | | - |
357 | | - # TODO: It seems IPIV is 1-indexed in FORTRAN, so we need to subtract 1. I can't find evidence that scipy is doing |
358 | | - # this, though. |
359 | | - assert_allclose(IPIV - 1, ipiv) |
360 | | - |
361 | | - |
362 | | -@pytest.mark.parametrize("trans", [0, 1]) |
363 | | -@pytest.mark.parametrize("overwrite_a", [True, False]) |
364 | | -@pytest.mark.parametrize("overwrite_b", [True, False]) |
365 | | -@pytest.mark.parametrize("b_shape", [(5,), (5, 3)], ids=["b_1d", "b_2d"]) |
366 | | -def test_getrs(trans, overwrite_a, overwrite_b, b_shape): |
367 | | - from scipy.linalg import lu_factor |
368 | | - from scipy.linalg import lu_solve as sp_lu_solve |
369 | | - |
370 | | - from pytensor.link.numba.dispatch.slinalg import _getrf, _getrs |
371 | | - |
372 | | - # TODO: Refactor this test to use compare_numba_and_py after we implement lu_solve in pytensor |
373 | | - |
374 | | - @numba.njit() |
375 | | - def lu_solve(a, b, trans, overwrite_a, overwrite_b): |
376 | | - lu, ipiv, info = _getrf(a, overwrite_a=overwrite_a) |
377 | | - x, info = _getrs(lu, b, ipiv, trans=trans, overwrite_b=overwrite_b) |
378 | | - return x, lu, info |
379 | | - |
380 | | - a = np.random.normal(size=(5, 5)).astype(floatX) |
381 | | - b = np.random.normal(size=b_shape).astype(floatX) |
382 | | - |
383 | | - # inputs need to be fortran-contiguous going into getrf and getrs for the overwrite option to work |
384 | | - a = np.asfortranarray(a) |
385 | | - b = np.asfortranarray(b) |
386 | | - |
387 | | - lu_and_piv = lu_factor(a, overwrite_a=False) |
388 | | - x_sp = sp_lu_solve(lu_and_piv, b, trans, overwrite_b=False) |
389 | | - |
390 | | - x, lu, info = lu_solve( |
391 | | - a, b, trans, overwrite_a=overwrite_a, overwrite_b=overwrite_b |
392 | | - ) |
393 | | - assert info == 0 |
394 | | - if overwrite_a: |
395 | | - assert_allclose(a, lu) |
396 | | - if overwrite_b: |
397 | | - assert_allclose(b, x) |
398 | | - |
399 | | - assert_allclose(x, x_sp) |
400 | | - |
401 | | - |
402 | 330 | @pytest.mark.filterwarnings( |
403 | 331 | 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' |
404 | 332 | ) |
|
0 commit comments