Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion array_api_strict/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ._data_type_functions import finfo
from ._dtypes import DType, _floating_dtypes, _numeric_dtypes, complex64, complex128
from ._elementwise_functions import conj
from ._flags import get_array_api_strict_flags, requires_extension
from ._flags import get_array_api_strict_flags, requires_extension, requires_api_version
from ._manipulation_functions import reshape
from ._statistical_functions import _np_dtype_sumprod

Expand All @@ -23,6 +23,10 @@ class EighResult(NamedTuple):
eigenvalues: Array
eigenvectors: Array

class EigResult(NamedTuple):
eigenvalues: Array
eigenvectors: Array

class QRResult(NamedTuple):
Q: Array
R: Array
Expand Down Expand Up @@ -144,6 +148,63 @@ def eigvalsh(x: Array, /) -> Array:

return Array._new(np.linalg.eigvalsh(x._array), device=x.device)

@requires_extension('linalg')
@requires_api_version('2025.12')
def eigvals(x: Array, /) -> Array:
"""
Array API compatible wrapper for :py:func:`np.linalg.eigvals <numpy.linalg.eigvals>`.

See its docstring for more information.
"""
# Note: the restriction to floating-point dtypes only is different from
# np.linalg.eigvals.
if x.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in eigvals')

res = np.linalg.eigvals(x._array)

# numpy return reals for real inputs
res_dtype = res.dtype
if res.dtype == np.float32:
res_dtype = np.complex64
elif res.dtype == np.float64:
res_dtype = np.complex128

if res_dtype != res.dtype:
res = res.astype(res_dtype)

return Array._new(res, device=x.device)


@requires_extension('linalg')
@requires_api_version('2025.12')
def eig(x: Array, /) -> EigResult:
"""
Array API compatible wrapper for :py:func:`np.linalg.eig <numpy.linalg.eig>`.

See its docstring for more information.
"""
# Note: the restriction to floating-point dtypes only is different from
# np.linalg.eig.
if x.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in eig')

w, vr = np.linalg.eig(x._array)

# numpy return reals for real inputs
res_dtype = w.dtype
if w.dtype == np.float32:
res_dtype = np.complex64
elif w.dtype == np.float64:
res_dtype = np.complex128

if res_dtype != w.dtype:
w = w.astype(res_dtype)
vr = vr.astype(res_dtype)

return EigResult(Array._new(w, device=x.device), Array._new(vr, device=x.device))


@requires_extension('linalg')
def inv(x: Array, /) -> Array:
"""
Expand Down
Loading