|
1 | 1 | import warnings |
2 | 2 | from collections.abc import Callable |
| 3 | +from typing import cast as typing_cast |
3 | 4 |
|
4 | 5 | import numba |
5 | 6 | import numpy as np |
| 7 | +import scipy.linalg |
6 | 8 | from numba.core import types |
7 | 9 | from numba.extending import overload |
8 | 10 | from numba.np.linalg import _copy_to_fortran_order, ensure_lapack |
|
18 | 20 | ) |
19 | 21 | from pytensor.link.numba.dispatch.basic import numba_funcify |
20 | 22 | from pytensor.tensor.slinalg import ( |
| 23 | + LU, |
21 | 24 | BlockDiagonal, |
22 | 25 | Cholesky, |
23 | 26 | CholeskySolve, |
@@ -476,10 +479,11 @@ def impl(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]: |
476 | 479 | def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]: |
477 | 480 | """ |
478 | 481 | Placeholder for LU factorization; used by linalg.solve. |
479 | | -
|
480 | | - # TODO: Implement an LU_factor Op, then dispatch to this function in numba mode. |
481 | 482 | """ |
482 | | - return # type: ignore |
| 483 | + getrf = scipy.linalg.get_lapack_funcs("getrf", (A,)) |
| 484 | + A_copy, ipiv, info = getrf(A, overwrite_a=overwrite_a) |
| 485 | + |
| 486 | + return A_copy, ipiv |
483 | 487 |
|
484 | 488 |
|
485 | 489 | @overload(_getrf) |
@@ -515,6 +519,263 @@ def impl( |
515 | 519 | return impl |
516 | 520 |
|
517 | 521 |
|
| 522 | +def _lu_1( |
| 523 | + a: np.ndarray, |
| 524 | + permute_l: bool, |
| 525 | + check_finite: bool, |
| 526 | + p_indices: bool, |
| 527 | + overwrite_a: bool, |
| 528 | +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 529 | + """ |
| 530 | + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. |
| 531 | +
|
| 532 | + Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer |
| 533 | + array of row swaps, such that L[perm] @ U = A. |
| 534 | + """ |
| 535 | + return typing_cast( |
| 536 | + tuple[np.ndarray, np.ndarray, np.ndarray], |
| 537 | + linalg.lu( |
| 538 | + a, |
| 539 | + permute_l=permute_l, |
| 540 | + check_finite=check_finite, |
| 541 | + p_indices=p_indices, |
| 542 | + overwrite_a=overwrite_a, |
| 543 | + ), |
| 544 | + ) |
| 545 | + |
| 546 | + |
| 547 | +def _lu_2( |
| 548 | + a: np.ndarray, |
| 549 | + permute_l: bool, |
| 550 | + check_finite: bool, |
| 551 | + p_indices: bool, |
| 552 | + overwrite_a: bool, |
| 553 | +) -> tuple[np.ndarray, np.ndarray]: |
| 554 | + """ |
| 555 | + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. |
| 556 | +
|
| 557 | + Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the |
| 558 | + permuted L matrix, PL = P @ L. |
| 559 | + """ |
| 560 | + return typing_cast( |
| 561 | + tuple[np.ndarray, np.ndarray], |
| 562 | + linalg.lu( |
| 563 | + a, |
| 564 | + permute_l=permute_l, |
| 565 | + check_finite=check_finite, |
| 566 | + p_indices=p_indices, |
| 567 | + overwrite_a=overwrite_a, |
| 568 | + ), |
| 569 | + ) |
| 570 | + |
| 571 | + |
| 572 | +def _lu_3( |
| 573 | + a: np.ndarray, |
| 574 | + permute_l: bool, |
| 575 | + check_finite: bool, |
| 576 | + p_indices: bool, |
| 577 | + overwrite_a: bool, |
| 578 | +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: |
| 579 | + """ |
| 580 | + Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor. |
| 581 | +
|
| 582 | + Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation |
| 583 | + matrix, P @ L @ U = A. |
| 584 | + """ |
| 585 | + return typing_cast( |
| 586 | + tuple[np.ndarray, np.ndarray, np.ndarray], |
| 587 | + linalg.lu( |
| 588 | + a, |
| 589 | + permute_l=permute_l, |
| 590 | + check_finite=check_finite, |
| 591 | + p_indices=p_indices, |
| 592 | + overwrite_a=overwrite_a, |
| 593 | + ), |
| 594 | + ) |
| 595 | + |
| 596 | + |
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba implementation registered for _lu_1, the variant the LU dispatcher selects when p_indices is True.
    Returns a tuple of (perm, L, U), where perm is an integer array of row indices, such that L[perm] @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # The status code returned by _getrf is not inspected here.
        lu_packed, pivots, _info = _getrf(a, overwrite_a=overwrite_a)

        # Unit lower-triangular factor: ones on the diagonal plus the strictly-lower part of the packed result.
        L = np.eye(lu_packed.shape[-1], dtype=dtype)
        L += np.tril(lu_packed, k=-1)
        # Upper-triangular factor, diagonal included.
        U = np.triu(lu_packed)

        # Fortran is 1 indexed, so shift the pivots to 0-based indices.
        swaps = pivots - 1

        # Replay the recorded row interchanges, in order, on an identity ordering.
        row_order = np.arange(len(swaps))
        for i in range(len(swaps)):
            j = swaps[i]
            held = row_order[i]
            row_order[i] = row_order[j]
            row_order[j] = held

        # Invert the accumulated ordering to obtain the row indices handed back to the caller.
        perm = np.argsort(row_order)
        return perm, L, U

    return impl
| 638 | + |
| 639 | + |
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[[np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]]:
    """
    Numba implementation registered for _lu_2, the variant the LU dispatcher selects when p_indices is False and
    permute_l is True. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
    """

    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        # The status code returned by _getrf is not inspected here.
        lu_packed, pivots, _info = _getrf(a, overwrite_a=overwrite_a)

        # Unit lower-triangular factor: ones on the diagonal plus the strictly-lower part of the packed result.
        L = np.eye(lu_packed.shape[-1], dtype=dtype)
        L += np.tril(lu_packed, k=-1)
        # Upper-triangular factor, diagonal included.
        U = np.triu(lu_packed)

        # Fortran is 1 indexed, so shift the pivots to 0-based indices.
        swaps = pivots - 1

        # Replay the recorded row interchanges, in order, on an identity ordering.
        row_order = np.arange(len(swaps))
        for i in range(len(swaps)):
            j = swaps[i]
            held = row_order[i]
            row_order[i] = row_order[j]
            row_order[j] = held

        # Invert the accumulated ordering and use it to reorder the rows of L.
        perm = np.argsort(row_order)
        PL = L[perm]
        return PL, U

    return impl
| 681 | + |
| 682 | + |
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray, np.ndarray]
]:
    """
    Numba implementation registered for _lu_3, the variant the LU dispatcher selects when both permute_l and
    p_indices are False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # The status code returned by _getrf is not inspected here.
        lu_packed, pivots, _info = _getrf(a, overwrite_a=overwrite_a)

        # Unit lower-triangular factor: ones on the diagonal plus the strictly-lower part of the packed result.
        L = np.eye(lu_packed.shape[-1], dtype=dtype)
        L += np.tril(lu_packed, k=-1)
        # Upper-triangular factor, diagonal included.
        U = np.triu(lu_packed)

        # Fortran is 1 indexed, so shift the pivots to 0-based indices.
        swaps = pivots - 1

        # Replay the recorded row interchanges, in order, on an identity ordering.
        row_order = np.arange(len(swaps))
        for i in range(len(swaps)):
            j = swaps[i]
            held = row_order[i]
            row_order[i] = row_order[j]
            row_order[j] = held

        # Invert the accumulated ordering and materialise it as rows of an identity matrix.
        perm = np.argsort(row_order)
        P = np.eye(lu_packed.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
| 726 | + |
| 727 | + |
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
    """
    Build the numba-compiled function for an LU Op.

    The Op's flags (permute_l, check_finite, p_indices, overwrite_a) are read once here and captured by the compiled
    closure. Exactly one of the overload targets is called, chosen in this order: _lu_1 when p_indices is set,
    otherwise _lu_2 when permute_l is set, otherwise _lu_3.

    Raises NotImplementedError when the dtype of the first input starts with "complex".
    """
    permute_l = op.permute_l
    check_finite = op.check_finite
    p_indices = op.p_indices
    overwrite_a = op.overwrite_a

    input_dtype = node.inputs[0].dtype
    if str(input_dtype).startswith("complex"):
        raise NotImplementedError(
            "Complex inputs not currently supported by lu in Numba mode"
        )

    @numba_basic.numba_njit(inline="always")
    def lu(a):
        # Optional input validation: reject any inf or nan entry.
        if check_finite:
            has_bad_values = np.any(np.isinf(a) | np.isnan(a))
            if has_bad_values:
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu"
                )

        # The branches return tuples of different lengths; the flags are closure constants captured above.
        if p_indices:
            return _lu_1(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )

        if permute_l:
            return _lu_2(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )

        return _lu_3(
            a,
            permute_l=permute_l,
            check_finite=check_finite,
            p_indices=p_indices,
            overwrite_a=overwrite_a,
        )

    return lu
| 777 | + |
| 778 | + |
518 | 779 | def _getrs( |
519 | 780 | LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool |
520 | 781 | ) -> tuple[np.ndarray, int]: |
|
0 commit comments