-
Notifications
You must be signed in to change notification settings - Fork 39
WIP: add compatibility shims for {eig,eigvals} #379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| cross = get_xp(np)(_linalg.cross) | ||
| outer = get_xp(np)(_linalg.outer) | ||
| EighResult = _linalg.EighResult | ||
| EigResult = _linalg.EigResult | ||
| QRResult = _linalg.QRResult | ||
| SlogdetResult = _linalg.SlogdetResult | ||
| SVDResult = _linalg.SVDResult | ||
|
|
@@ -97,6 +98,85 @@ def solve(x1: Array, x2: Array, /) -> Array: | |
| return wrap(r.astype(result_t, copy=False)) | ||
|
|
||
|
|
||
| # Unlike numpy.linalg.eig, Array API version always returns complex results | ||
|
|
||
| def eig(x: Array, /) -> tuple[Array, Array]: | ||
| try: | ||
| from numpy.linalg._linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| except ImportError: | ||
| from numpy.linalg.linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| from numpy.linalg import _umath_linalg | ||
|
|
||
| x, wrap = _makearray(x) | ||
| _assert_stacked_square(x) | ||
| _assert_finite(x) | ||
| t, result_t = _commonType(x) | ||
|
|
||
| signature = 'D->DD' if isComplexType(t) else 'd->DD' | ||
| with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, | ||
| invalid='call', over='ignore', divide='ignore', | ||
| under='ignore'): | ||
| w, vt = _umath_linalg.eig(x, signature=signature) | ||
|
|
||
| result_t = _complexType(result_t) | ||
| vt = vt.astype(result_t, copy=False) | ||
| return EigResult(w.astype(result_t, copy=False), wrap(vt)) | ||
|
|
||
|
|
||
| def eigvals(x: Array, /) -> Array: | ||
| try: | ||
| from numpy.linalg._linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| except ImportError: | ||
| from numpy.linalg.linalg import ( # type: ignore[attr-defined] | ||
| _assert_stacked_square, | ||
| _assert_finite, | ||
| _commonType, | ||
| _makearray, | ||
| _raise_linalgerror_eigenvalues_nonconvergence, | ||
| isComplexType, | ||
| _complexType, | ||
| ) | ||
| from numpy.linalg import _umath_linalg | ||
|
|
||
| x, wrap = _makearray(x) | ||
|
||
| _assert_stacked_square(x) | ||
| _assert_finite(x) | ||
| t, result_t = _commonType(x) | ||
|
|
||
| signature = 'D->D' if isComplexType(t) else 'd->D' | ||
| with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, | ||
| invalid='call', over='ignore', divide='ignore', | ||
| under='ignore'): | ||
| w = _umath_linalg.eigvals(x, signature=signature) | ||
|
|
||
| result_t = _complexType(result_t) | ||
| return w.astype(result_t, copy=False) | ||
|
Comment on lines
+103
to
+177
|
||
|
|
||
|
|
||
| # These functions are completely new here. If the library already has them | ||
| # (i.e., numpy 2.0), use the library version instead of our wrapper. | ||
| if hasattr(np.linalg, "vector_norm"): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,6 +129,10 @@ array_api_tests/test_linalg.py::test_matrix_norm | |
| array_api_tests/test_linalg.py::test_qr | ||
| array_api_tests/test_manipulation_functions.py::test_roll | ||
|
|
||
| # 2025.12 support | ||
| array_api_tests/test_linalg.py::test_eig | ||
| array_api_tests/test_linalg.py::test_eigvals | ||
|
|
||
|
Comment on lines
+132
to
+135
|
||
| # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) | ||
| array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] | ||
| array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type annotation for the eig function indicates it returns a tuple[Array, Array], but the implementation returns an EigResult (a NamedTuple). This inconsistency should be corrected to return EigResult instead of tuple[Array, Array] to match the actual behavior and to be consistent with similar functions like eigh which properly returns EighResult.