|
8 | 8 |
|
9 | 9 | import pytensor |
10 | 10 | import pytensor.tensor as pt |
11 | | -from pytensor import config |
12 | | -from pytensor.tensor.slinalg import SolveTriangular |
| 11 | +from pytensor import In, config |
| 12 | +from pytensor.tensor import TensorVariable |
| 13 | +from pytensor.tensor.slinalg import Solve, SolveTriangular |
13 | 14 | from tests import unittest_tools as utt |
14 | 15 | from tests.link.numba.test_basic import compare_numba_and_py |
15 | 16 |
|
@@ -408,66 +409,109 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b): |
408 | 409 | @pytest.mark.filterwarnings( |
409 | 410 | 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' |
410 | 411 | ) |
411 | | -def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]): |
412 | | - A = pt.matrix("A", dtype=floatX) |
413 | | - b = pt.tensor("b", shape=b_shape, dtype=floatX) |
414 | | - |
415 | | - A_val = np.asfortranarray(np.random.normal(size=(5, 5)).astype(floatX)) |
416 | | - b_val = np.asfortranarray(np.random.normal(size=b_shape).astype(floatX)) |
417 | | - |
| 412 | +@pytest.mark.parametrize( |
| 413 | + "overwrite_a, overwrite_b", |
| 414 | + [(False, False), (True, False), (False, True)], |
| 415 | + ids=["no_overwrite", "overwrite_a", "overwrite_b"], |
| 416 | +) |
| 417 | +def test_solve( |
| 418 | + b_shape: tuple[int], |
| 419 | + assume_a: Literal["gen", "sym", "pos"], |
| 420 | + overwrite_a: bool, |
| 421 | + overwrite_b: bool, |
| 422 | +): |
418 | 423 | def A_func(x): |
419 | 424 | if assume_a == "pos": |
420 | 425 | x = x @ x.T |
421 | 426 | elif assume_a == "sym": |
422 | 427 | x = (x + x.T) / 2 |
| 428 | + elif assume_a == "tridiagonal": |
| 429 | + lib = pt if isinstance(x, TensorVariable) else np |
| 430 | + diag_fn = lib.diag |
| 431 | + eye_fn = lib.eye |
| 432 | + concatenate_fn = lib.concatenate |
| 433 | + |
| 434 | + ud = diag_fn(x, 1) |
| 435 | + ld = diag_fn(x, -1) |
| 436 | + # Zero out the first super- and sub-diagonals of x before summing columns for d |
| 437 | + d = (x - diag_fn(ud, 1) - diag_fn(ld, -1)).sum(0) |
| 438 | + return x * ( |
| 439 | + eye_fn(x.shape[1], k=0) * d |
| 440 | + + eye_fn(x.shape[1], k=-1) * concatenate_fn([[0], ld], axis=-1) |
| 441 | + + eye_fn(x.shape[1], k=1) * concatenate_fn([ud, [0]], axis=-1) |
| 442 | + ) |
423 | 443 | return x |
424 | 444 |
|
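For intuition on the tridiagonal branch of A_func above: it produces a matrix whose only nonzero entries lie on the main, first super- and first sub-diagonals, computed the same way whether x is a symbolic TensorVariable or a NumPy array. A minimal NumPy-only sketch of the banding idea (a simplified stand-in, not the exact transformation used in the test):

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(5, 5))
    n = x.shape[1]

    # Mask keeping only the main, first super- and first sub-diagonals.
    band = np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
    x_tri = x * band

    # Everything outside the tridiagonal band is zero.
    assert not np.triu(x_tri, k=2).any()
    assert not np.tril(x_tri, k=-2).any()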
| 445 | + A = pt.matrix("A", dtype=floatX) |
| 446 | + b = pt.tensor("b", shape=b_shape, dtype=floatX) |
| 447 | + |
| 448 | + rng = np.random.default_rng(418) |
| 449 | + A_val = np.asfortranarray(A_func(rng.normal(size=(5, 5))).astype(floatX)) |
| 450 | + b_val = np.asfortranarray(rng.normal(size=b_shape).astype(floatX)) |
| 451 | + |
425 | 452 | X = pt.linalg.solve( |
426 | | - A_func(A), |
| 453 | + A, |
427 | 454 | b, |
428 | 455 | assume_a=assume_a, |
429 | 456 | b_ndim=len(b_shape), |
430 | 457 | ) |
431 | | - f = pytensor.function( |
432 | | - [pytensor.In(A, mutable=True), pytensor.In(b, mutable=True)], X, mode="NUMBA" |
433 | | - ) |
434 | | - op = f.maker.fgraph.outputs[0].owner.op |
435 | | - |
436 | | - compare_numba_and_py([A, b], [X], test_inputs=[A_val, b_val], inplace=True) |
437 | | - |
438 | | - # Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first. |
439 | | - A_val_copy = A_val.copy() |
440 | | - b_val_copy = b_val.copy() |
441 | 458 |
|
442 | | - X_np = f(A_val, b_val) |
443 | | - |
444 | | - # overwrite_b is preferred when both inputs can be destroyed |
445 | | - assert op.destroy_map == {0: [1]} |
446 | | - |
447 | | - # Confirm inputs were destroyed by checking against the copies |
448 | | - assert (A_val == A_val_copy).all() == (op.destroy_map.get(0, None) != [0]) |
449 | | - assert (b_val == b_val_copy).all() == (op.destroy_map.get(0, None) != [1]) |
450 | | - |
451 | | - ATOL = 1e-8 if floatX.endswith("64") else 1e-4 |
452 | | - RTOL = 1e-8 if floatX.endswith("64") else 1e-4 |
| 459 | + f, res = compare_numba_and_py( |
| 460 | + [In(A, mutable=overwrite_a), In(b, mutable=overwrite_b)], |
| 461 | + X, |
| 462 | + test_inputs=[A_val, b_val], |
| 463 | + inplace=True, |
| 464 | + numba_mode="NUMBA",  # Use the default NUMBA mode so that inplace rewrites are triggered |
| 465 | + ) |
| 466 | + f.dprint(print_memory_map=True) |
453 | 467 |
|
454 | | - # Confirm b_val is used to store to solution |
455 | | - np.testing.assert_allclose(X_np, b_val, atol=ATOL, rtol=RTOL) |
456 | | - assert not np.allclose(b_val, b_val_copy) |
| 468 | + op = f.maker.fgraph.outputs[0].owner.op |
| 469 | + assert isinstance(op, Solve) |
| 470 | + destroy_map = op.destroy_map |
| 471 | + if overwrite_a and overwrite_b: |
| 472 | + raise NotImplementedError( |
| 473 | + "Test not implemented for symultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor" |
| 474 | + ) |
| 475 | + elif overwrite_a: |
| 476 | + assert destroy_map == {0: [0]} |
| 477 | + elif overwrite_b: |
| 478 | + assert destroy_map == {0: [1]} |
| 479 | + else: |
| 480 | + assert destroy_map == {} |
| 481 | + |
| 482 | + # Check that inputs are destroyed when allowed |
| 483 | + A_val_f_contig = np.copy(A_val, order="F") |
| 484 | + b_val_f_contig = np.copy(b_val, order="F") |
| 485 | + res_f_contig = f(A_val_f_contig, b_val_f_contig) |
| 486 | + np.testing.assert_allclose(res_f_contig, res) |
| 487 | + assert (A_val == A_val_f_contig).all() == (op.destroy_map.get(0, None) != [0]) |
| 488 | + assert (b_val == b_val_f_contig).all() == (op.destroy_map.get(0, None) != [1]) |
| 489 | + |
| 490 | + # Check that results are still correct when inputs cannot be destroyed because they are not F-contiguous |
| 491 | + A_val_c_contig = np.copy(A_val, order="C") |
| 492 | + b_val_c_contig = np.copy(b_val, order="C") |
| 493 | + res_c_contig = f(A_val_c_contig, b_val_c_contig) |
| 494 | + np.testing.assert_allclose(res_c_contig, res) |
| 495 | + if assume_a == "sym" and overwrite_a: |
| 496 | + # For symmetric A the C- and F-contiguous buffers hold the same values, so either layout can be destroyed |
| 497 | + assert not np.allclose(A_val_c_contig, A_val) |
| 498 | + else: |
| 499 | + np.testing.assert_allclose(A_val_c_contig, A_val) |
| 500 | + np.testing.assert_allclose(b_val_c_contig, b_val) |
457 | 501 |
|
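For reference, the overwrite machinery checked above boils down to marking an input as mutable and then inspecting the op's destroy_map. A minimal sketch, assuming the Numba backend is installed and the default float64 floatX (variable names are illustrative, not from this PR):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor import In

    A = pt.matrix("A")
    b = pt.vector("b")
    x = pt.linalg.solve(A, b, assume_a="gen", b_ndim=1)

    # Declaring b as mutable allows the compiled function to reuse its buffer.
    f = pytensor.function([In(A, mutable=False), In(b, mutable=True)], x, mode="NUMBA")

    op = f.maker.fgraph.outputs[0].owner.op
    # {0: [1]} means output 0 overwrites input 1 (b); {} means nothing is destroyed.
    print(op.destroy_map)

    b_val = np.random.normal(size=5)
    b_copy = b_val.copy()
    f(np.random.normal(size=(5, 5)), b_val)
    print(np.array_equal(b_val, b_copy))  # False if b was overwritten in place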
458 | | - # Test that the result is numerically correct. Need to use the unmodified copy |
459 | | - np.testing.assert_allclose( |
460 | | - A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL |
| 502 | + # Check that results are still correct when inputs are contiguous in neither memory order |
| 503 | + A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2] |
| 504 | + assert not ( |
| 505 | + A_val_not_contig.flags.c_contiguous or A_val_not_contig.flags.f_contiguous |
461 | 506 | ) |
462 | | - |
463 | | - # See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here |
464 | | - utt.verify_grad( |
465 | | - lambda A, b: pt.linalg.solve( |
466 | | - A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape) |
467 | | - ), |
468 | | - [A_val_copy, b_val_copy], |
469 | | - mode="NUMBA", |
| 507 | + b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2] |
| 508 | + assert not ( |
| 509 | + b_val_not_contig.flags.c_contiguous or b_val_not_contig.flags.f_contiguous |
470 | 510 | ) |
| 511 | + res_not_contig = f(A_val_not_contig, b_val_not_contig) |
| 512 | + np.testing.assert_allclose(res_not_contig, res) |
| 513 | + np.testing.assert_allclose(A_val_not_contig, A_val) |
| 514 | + np.testing.assert_allclose(b_val_not_contig, b_val) |
471 | 515 |
|
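The repeat-then-slice trick above is a compact way to build an array that is contiguous in neither memory order, so the inplace path cannot fire and the input values must be preserved. A NumPy-only sketch:

    import numpy as np

    A = np.arange(25, dtype="float64").reshape(5, 5)

    # Repeating rows gives a (10, 5) C-contiguous array; taking every other row
    # yields a strided view that is neither C- nor F-contiguous.
    A_not_contig = np.repeat(A, 2, axis=0)[::2]

    assert not A_not_contig.flags.c_contiguous
    assert not A_not_contig.flags.f_contiguous
    np.testing.assert_allclose(A_not_contig, A)  # same values as the original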
472 | 516 |
|
473 | 517 | @pytest.mark.parametrize( |
|