|
1 | 1 | import warnings |
2 | 2 | from collections.abc import Callable |
| 3 | +from typing import cast as typing_cast |
3 | 4 |
|
4 | 5 | import numba |
5 | 6 | import numpy as np |
| 7 | +import scipy.linalg |
6 | 8 | from numba.core import types |
7 | 9 | from numba.extending import overload |
8 | 10 | from numba.np.linalg import _copy_to_fortran_order, ensure_lapack |
|
18 | 20 | ) |
19 | 21 | from pytensor.link.numba.dispatch.basic import numba_funcify |
20 | 22 | from pytensor.tensor.slinalg import ( |
| 23 | + LU, |
21 | 24 | BlockDiagonal, |
22 | 25 | Cholesky, |
23 | 26 | CholeskySolve, |
@@ -492,10 +495,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: |
492 | 495 | def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: |
493 | 496 | """ |
494 | 497 | Placeholder for LU factorization; used by linalg.solve. |
495 | | -
|
496 | | - # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. |
497 | 498 | """ |
498 | | - return # type: ignore |
| 499 | + getrf = scipy.linalg.get_lapack_funcs("getrf", (A,)) |
| 500 | + A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a) |
| 501 | + |
| 502 | + return A_copy, ipiv |
499 | 503 |
|
500 | 504 |
|
501 | 505 | @overload(_getrf) |
@@ -531,6 +535,263 @@ def impl( |
531 | 535 | return impl |
532 | 536 |
|
533 | 537 |
|
| 538 | +def _lu_1( |
| 539 | + a: np.ndarray, |
| 540 | + permute_l: bool, |
| 541 | + check_finite: bool, |
| 542 | + p_indices: bool, |
| 543 | + overwrite_a: bool, |
| 544 | +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 545 | + """ |
| 546 | + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. |
| 547 | +
|
| 548 | + Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer |
| 549 | + array of row swaps, such that L[perm] @ U = A. |
| 550 | + """ |
| 551 | + return typing_cast( |
| 552 | + tuple[np.ndarray, np.ndarray, np.ndarray], |
| 553 | + linalg.lu( |
| 554 | + a, |
| 555 | + permute_l=permute_l, |
| 556 | + check_finite=check_finite, |
| 557 | + p_indices=p_indices, |
| 558 | + overwrite_a=overwrite_a, |
| 559 | + ), |
| 560 | + ) |
| 561 | + |
| 562 | + |
| 563 | +def _lu_2( |
| 564 | + a: np.ndarray, |
| 565 | + permute_l: bool, |
| 566 | + check_finite: bool, |
| 567 | + p_indices: bool, |
| 568 | + overwrite_a: bool, |
| 569 | +) -> tuple[np.ndarray, np.ndarray]: |
| 570 | + """ |
| 571 | + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. |
| 572 | +
|
| 573 | + Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the |
| 574 | + permuted L matrix, PL = P @ L. |
| 575 | + """ |
| 576 | + return typing_cast( |
| 577 | + tuple[np.ndarray, np.ndarray], |
| 578 | + linalg.lu( |
| 579 | + a, |
| 580 | + permute_l=permute_l, |
| 581 | + check_finite=check_finite, |
| 582 | + p_indices=p_indices, |
| 583 | + overwrite_a=overwrite_a, |
| 584 | + ), |
| 585 | + ) |
| 586 | + |
| 587 | + |
| 588 | +def _lu_3( |
| 589 | + a: np.ndarray, |
| 590 | + permute_l: bool, |
| 591 | + check_finite: bool, |
| 592 | + p_indices: bool, |
| 593 | + overwrite_a: bool, |
| 594 | +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 595 | + """ |
| 596 | + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. |
| 597 | +
|
| 598 | + Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation |
| 599 | + matrix, P @ L @ U = A. |
| 600 | + """ |
| 601 | + return typing_cast( |
| 602 | + tuple[np.ndarray, np.ndarray, np.ndarray], |
| 603 | + linalg.lu( |
| 604 | + a, |
| 605 | + permute_l=permute_l, |
| 606 | + check_finite=check_finite, |
| 607 | + p_indices=p_indices, |
| 608 | + overwrite_a=overwrite_a, |
| 609 | + ), |
| 610 | + ) |
| 611 | + |
| 612 | + |
| 613 | +@overload(_lu_1) |
| 614 | +def lu_impl_1( |
| 615 | + a: np.ndarray, |
| 616 | + permute_l: bool, |
| 617 | + check_finite: bool, |
| 618 | + p_indices: bool, |
| 619 | + overwrite_a: bool, |
| 620 | +) -> Callable[ |
| 621 | + [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] |
| 622 | +]: |
| 623 | + """ |
| 624 | + Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is |
| 625 | + False. Returns a tuple of (perm, L, U), where perm an integer array of row swaps, such that L[perm] @ U = A. |
| 626 | + """ |
| 627 | + ensure_lapack() |
| 628 | + _check_scipy_linalg_matrix(a, "lu") |
| 629 | + dtype = a.dtype |
| 630 | + |
| 631 | + def impl( |
| 632 | + a: np.ndarray, |
| 633 | + permute_l: bool, |
| 634 | + check_finite: bool, |
| 635 | + p_indices: bool, |
| 636 | + overwrite_a: bool, |
| 637 | + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 638 | + A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a) |
| 639 | + |
| 640 | + L = np.eye(A_copy.shape[-1], dtype=dtype) |
| 641 | + L += np.tril(A_copy, k=-1) |
| 642 | + U = np.triu(A_copy) |
| 643 | + |
| 644 | + # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array |
| 645 | + IPIV = IPIV - 1 |
| 646 | + p_inv = np.arange(len(IPIV)) |
| 647 | + for i in range(len(IPIV)): |
| 648 | + p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i] |
| 649 | + |
| 650 | + perm = np.argsort(p_inv) |
| 651 | + return perm, L, U |
| 652 | + |
| 653 | + return impl |
| 654 | + |
| 655 | + |
| 656 | +@overload(_lu_2) |
| 657 | +def lu_impl_2( |
| 658 | + a: np.ndarray, |
| 659 | + permute_l: bool, |
| 660 | + check_finite: bool, |
| 661 | + p_indices: bool, |
| 662 | + overwrite_a: bool, |
| 663 | +) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]: |
| 664 | + """ |
| 665 | + Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is |
| 666 | + True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L. |
| 667 | + """ |
| 668 | + |
| 669 | + ensure_lapack() |
| 670 | + _check_scipy_linalg_matrix(a, "lu") |
| 671 | + dtype = a.dtype |
| 672 | + |
| 673 | + def impl( |
| 674 | + a: np.ndarray, |
| 675 | + permute_l: bool, |
| 676 | + check_finite: bool, |
| 677 | + p_indices: bool, |
| 678 | + overwrite_a: bool, |
| 679 | + ) -> tuple[np.ndarray, np.ndarray]: |
| 680 | + A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a) |
| 681 | + |
| 682 | + L = np.eye(A_copy.shape[-1], dtype=dtype) |
| 683 | + L += np.tril(A_copy, k=-1) |
| 684 | + U = np.triu(A_copy) |
| 685 | + |
| 686 | + # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array |
| 687 | + IPIV = IPIV - 1 |
| 688 | + p_inv = np.arange(len(IPIV)) |
| 689 | + for i in range(len(IPIV)): |
| 690 | + p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i] |
| 691 | + |
| 692 | + perm = np.argsort(p_inv) |
| 693 | + PL = L[perm] |
| 694 | + return PL, U |
| 695 | + |
| 696 | + return impl |
| 697 | + |
| 698 | + |
| 699 | +@overload(_lu_3) |
| 700 | +def lu_impl_3( |
| 701 | + a: np.ndarray, |
| 702 | + permute_l: bool, |
| 703 | + check_finite: bool, |
| 704 | + p_indices: bool, |
| 705 | + overwrite_a: bool, |
| 706 | +) -> Callable[ |
| 707 | + [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray] |
| 708 | +]: |
| 709 | + """ |
| 710 | + Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is |
| 711 | + False. Returns a tuple of (P, L, U), such that P @ L @ U = A. |
| 712 | + """ |
| 713 | + ensure_lapack() |
| 714 | + _check_scipy_linalg_matrix(a, "lu") |
| 715 | + dtype = a.dtype |
| 716 | + |
| 717 | + def impl( |
| 718 | + a: np.ndarray, |
| 719 | + permute_l: bool, |
| 720 | + check_finite: bool, |
| 721 | + p_indices: bool, |
| 722 | + overwrite_a: bool, |
| 723 | + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 724 | + A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a) |
| 725 | + |
| 726 | + L = np.eye(A_copy.shape[-1], dtype=dtype) |
| 727 | + L += np.tril(A_copy, k=-1) |
| 728 | + U = np.triu(A_copy) |
| 729 | + |
| 730 | + # Fortran is 1 indexed, so we need to subtract 1 from the IPIV array |
| 731 | + IPIV = IPIV - 1 |
| 732 | + p_inv = np.arange(len(IPIV)) |
| 733 | + for i in range(len(IPIV)): |
| 734 | + p_inv[i], p_inv[IPIV[i]] = p_inv[IPIV[i]], p_inv[i] |
| 735 | + |
| 736 | + perm = np.argsort(p_inv) |
| 737 | + P = np.eye(A_copy.shape[-1], dtype=dtype)[perm] |
| 738 | + |
| 739 | + return P, L, U |
| 740 | + |
| 741 | + return impl |
| 742 | + |
| 743 | + |
| 744 | +@numba_funcify.register(LU) |
| 745 | +def numba_funcify_LU(op, node, **kwargs): |
| 746 | + permute_l = op.permute_l |
| 747 | + check_finite = op.check_finite |
| 748 | + p_indices = op.p_indices |
| 749 | + overwrite_a = op.overwrite_a |
| 750 | + |
| 751 | + dtype = node.inputs[0].dtype |
| 752 | + if str(dtype).startswith("complex"): |
| 753 | + raise NotImplementedError( |
| 754 | + "Complex inputs not currently supported by lu in Numba mode" |
| 755 | + ) |
| 756 | + |
| 757 | + @numba_basic.numba_njit(inline="always") |
| 758 | + def lu(a): |
| 759 | + if check_finite: |
| 760 | + if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))): |
| 761 | + raise np.linalg.LinAlgError( |
| 762 | + "Non-numeric values (nan or inf) found in input to lu" |
| 763 | + ) |
| 764 | + |
| 765 | + if p_indices: |
| 766 | + res = _lu_1( |
| 767 | + a, |
| 768 | + permute_l=permute_l, |
| 769 | + check_finite=check_finite, |
| 770 | + p_indices=p_indices, |
| 771 | + overwrite_a=overwrite_a, |
| 772 | + ) |
| 773 | + elif permute_l: |
| 774 | + res = _lu_2( |
| 775 | + a, |
| 776 | + permute_l=permute_l, |
| 777 | + check_finite=check_finite, |
| 778 | + p_indices=p_indices, |
| 779 | + overwrite_a=overwrite_a, |
| 780 | + ) |
| 781 | + else: |
| 782 | + res = _lu_3( |
| 783 | + a, |
| 784 | + permute_l=permute_l, |
| 785 | + check_finite=check_finite, |
| 786 | + p_indices=p_indices, |
| 787 | + overwrite_a=overwrite_a, |
| 788 | + ) |
| 789 | + |
| 790 | + return res |
| 791 | + |
| 792 | + return lu |
| 793 | + |
| 794 | + |
534 | 795 | def _getrs( |
535 | 796 | LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool |
536 | 797 | ) -> tuple[np.ndarray, int]: |
|
0 commit comments