|
8 | 8 |
|
9 | 9 | import pytensor |
10 | 10 | import pytensor.tensor as pt |
11 | | -from pytensor import config |
12 | | -from pytensor.tensor.slinalg import SolveTriangular |
| 11 | +from pytensor import In, config |
| 12 | +from pytensor.tensor.slinalg import Solve, SolveTriangular |
13 | 13 | from tests import unittest_tools as utt |
14 | 14 | from tests.link.numba.test_basic import compare_numba_and_py |
15 | 15 |
|
@@ -399,75 +399,98 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b): |
399 | 399 | assert_allclose(x, x_sp) |
400 | 400 |
|
401 | 401 |
|
| 402 | +@pytest.mark.filterwarnings( |
| 403 | + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' |
| 404 | +) |
402 | 405 | @pytest.mark.parametrize( |
403 | 406 | "b_shape", |
404 | 407 | [(5, 1), (5, 5), (5,)], |
405 | 408 | ids=["b_col_vec", "b_matrix", "b_vec"], |
406 | 409 | ) |
407 | 410 | @pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str) |
408 | | -@pytest.mark.filterwarnings( |
409 | | - 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' |
| 411 | +@pytest.mark.parametrize( |
| 412 | + "overwrite_a, overwrite_b", |
| 413 | + [(False, False), (True, False), (False, True)], |
| 414 | + ids=["no_overwrite", "overwrite_a", "overwrite_b"], |
410 | 415 | ) |
411 | | -def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): |
412 | | - A = pt.matrix("A", dtype=floatX) |
413 | | - b = pt.tensor("b", shape=b_shape, dtype=floatX) |
414 | | - |
415 | | - A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX)) |
416 | | - b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX)) |
417 | | - |
| 416 | +def test_solve( |
| 417 | + b_shape: tuple[int], |
| 418 | + assume_a: Literal["gen", "sym", "pos"], |
| 419 | + overwrite_a: bool, |
| 420 | + overwrite_b: bool, |
| 421 | +): |
418 | 422 | def A_func(x): |
419 | 423 | if assume_a == "pos": |
420 | 424 | x = x @ x.T |
421 | 425 | elif assume_a == "sym": |
422 | 426 | x = (x + x.T) / 2 |
423 | 427 | return x |
424 | 428 |
|
| 429 | + A = pt.matrix("A", dtype=floatX) |
| 430 | + b = pt.tensor("b", shape=b_shape, dtype=floatX) |
| 431 | + |
| 432 | + rng = np.random.default_rng(418) |
| 433 | + A_val = np.asfortranarray(A_func(rng.normal(size=(5, 5))).astype(floatX)) |
| 434 | + b_val = np.asfortranarray(rng.normal(size=b_shape).astype(floatX)) |
| 435 | + |
425 | 436 | X = pt.linalg.solve( |
426 | | - A_func(A), |
| 437 | + A, |
427 | 438 | b, |
428 | 439 | assume_a=assume_a, |
429 | 440 | b_ndim=len(b_shape), |
430 | 441 | ) |
431 | | - f = pytensor.function( |
432 | | - [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA" |
433 | | - ) |
434 | | - op = f.maker.fgraph.outputs[0].owner.op |
435 | 442 |
|
436 | | - compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True) |
437 | | - |
438 | | - # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. |
439 | | - A_val_copy = A_val.copy() |
440 | | - b_val_copy = b_val.copy() |
441 | | - |
442 | | - X_np = f(A_val, b_val) |
443 | | - |
444 | | - # overwrite_b is preferred when both inputs can be destroyed |
445 | | - assert op.destroy_map == {0: [1]} |
446 | | - |
447 | | - # Confirm inputs were destroyed by checking against the copies |
448 | | - assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) |
449 | | - assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) |
450 | | - |
451 | | - ATOL = 1e-8 if floatX.endswith("64") else 1e-4 |
452 | | - RTOL = 1e-8 if floatX.endswith("64") else 1e-4 |
453 | | - |
454 | | - # Confirm b_val is used to store to solution |
455 | | - np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) |
456 | | - assert not np.allclose(b_val, b_val_copy) |
457 | | - |
458 | | - # Test that the result is numerically correct. Need to use the unmodified copy |
459 | | - np.testing.assert_allclose( |
460 | | - A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL |
| 443 | + f, res = compare_numba_and_py( |
| 444 | + [In(A, mutable=overwrite_a), In(b, mutable=overwrite_b)], |
| 445 | + X, |
| 446 | + test_inputs=[A_val, b_val], |
| 447 | + inplace=True, |
| 448 | +        numba_mode="NUMBA",  # Default numba mode, so inplace rewrites get triggered |
461 | 449 | ) |
462 | 450 |
|
463 | | - # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here |
464 | | - utt.verify_grad( |
465 | | - lambda A, b: pt.linalg.solve( |
466 | | - A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape) |
467 | | - ), |
468 | | - [A_val_copy, b_val_copy], |
469 | | - mode="NUMBA", |
| 451 | + op = f.maker.fgraph.outputs[0].owner.op |
| 452 | + assert isinstance(op, Solve) |
| 453 | + destroy_map = op.destroy_map |
| 454 | + if overwrite_a and overwrite_b: |
| 455 | + raise NotImplementedError( |
| 456 | +            "Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor" |
| 457 | + ) |
| 458 | + elif overwrite_a: |
| 459 | + assert destroy_map == {0: [0]} |
| 460 | + elif overwrite_b: |
| 461 | + assert destroy_map == {0: [1]} |
| 462 | + else: |
| 463 | + assert destroy_map == {} |
| 464 | + |
| 465 | + # Test inputs are destroyed if possible |
| 466 | + A_val_f_contig = np.copy(A_val, order="F") |
| 467 | + b_val_f_contig = np.copy(b_val, order="F") |
| 468 | + res_f_contig = f(A_val_f_contig, b_val_f_contig) |
| 469 | + np.testing.assert_allclose(res_f_contig, res) |
| 470 | + assert (A_val == A_val_f_contig).all() == (not overwrite_a) |
| 471 | + assert (b_val == b_val_f_contig).all() == (not overwrite_b) |
| 472 | + |
| 473 | +    # Test correct results even if the input cannot be destroyed because it is not F-contiguous |
| 474 | + A_val_c_contig = np.copy(A_val, order="C") |
| 475 | + b_val_c_contig = np.copy(b_val, order="C") |
| 476 | + res_c_contig = f(A_val_c_contig, b_val_c_contig) |
| 477 | + np.testing.assert_allclose(res_c_contig, res) |
| 478 | + # We can actually destroy either C or F-contiguous arrays |
| 479 | + assert np.allclose(A_val_c_contig, A_val) == ( |
| 480 | + not (overwrite_a and assume_a in ("sym", "pos")) |
470 | 481 | ) |
| 482 | + # Vectors are always f_contiguous if also c_contiguous |
| 483 | + assert np.allclose(b_val_c_contig, b_val) == ( |
| 484 | + not (overwrite_b and b_val_c_contig.flags.f_contiguous) |
| 485 | + ) |
| 486 | + |
| 487 | +    # Test correct results if inputs are not contiguous in either format |
| 488 | + A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2] |
| 489 | + b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2] |
| 490 | + res_not_contig = f(A_val_not_contig, b_val_not_contig) |
| 491 | + np.testing.assert_allclose(res_not_contig, res) |
| 492 | + np.testing.assert_allclose(A_val_not_contig, A_val) |
| 493 | + np.testing.assert_allclose(b_val_not_contig, b_val) |
471 | 494 |
|
472 | 495 |
|
473 | 496 | @pytest.mark.parametrize( |
|
0 commit comments