Skip to content

Commit 96e63ea

Browse files
committed
jnp.linalg: add symmetrize_input argument & docs
1 parent 74917ce commit 96e63ea

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

jax/_src/numpy/linalg.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
7272

7373

7474
@export
75-
@partial(jit, static_argnames=['upper'])
76-
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
75+
@partial(jit, static_argnames=['upper', 'symmetrize_input'])
76+
def cholesky(a: ArrayLike, *, upper: bool = False, symmetrize_input: bool = True) -> Array:
7777
"""Compute the Cholesky decomposition of a matrix.
7878
7979
JAX implementation of :func:`numpy.linalg.cholesky`.
@@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
9898
Must have shape ``(..., N, N)``.
9999
upper: if True, compute the upper Cholesky decomposition `U`. if False
100100
(default), compute the lower Cholesky decomposition `L`.
101+
symmetrize_input: if True (default) then input is symmetrized, which leads
102+
to better behavior under automatic differentiation. Note that when this
103+
is set to True, both the upper and lower triangles of the input will
104+
be used in computing the decomposition.
101105
102106
Returns:
103107
array of shape ``(..., N, N)`` representing the Cholesky decomposition
@@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
135139
"""
136140
a = ensure_arraylike("jnp.linalg.cholesky", a)
137141
a, = promote_dtypes_inexact(a)
138-
L = lax_linalg.cholesky(a)
142+
L = lax_linalg.cholesky(a, symmetrize_input=symmetrize_input)
139143
return L.mT.conj() if upper else L
140144

141145

@@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
821825
UPLO: specifies whether the calculation is done with the lower triangular
822826
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
823827
symmetrize_input: if True (default) then input is symmetrized, which leads
824-
to better behavior under automatic differentiation.
828+
to better behavior under automatic differentiation. Note that when this
829+
is set to True, both the upper and lower triangles of the input will
830+
be used in computing the decomposition.
825831
826832
Returns:
827833
A namedtuple ``(eigenvalues, eigenvectors)`` where
@@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
863869

864870

865871
@export
866-
@partial(jit, static_argnames=('UPLO',))
867-
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
872+
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
873+
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L', *,
874+
symmetrize_input: bool = True) -> Array:
868875
"""
869876
Compute the eigenvalues of a Hermitian matrix.
870877
@@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
875882
or symmetric (if real) matrix.
876883
UPLO: specifies whether the calculation is done with the lower triangular
877884
part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
885+
symmetrize_input: if True (default) then input is symmetrized, which leads
886+
to better behavior under automatic differentiation. Note that when this
887+
is set to True, both the upper and lower triangles of the input will
888+
be used in computing the decomposition.
878889
879890
Returns:
880891
An array of shape ``(..., M)`` containing the eigenvalues, sorted in
@@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
894905
"""
895906
a = ensure_arraylike("jnp.linalg.eigvalsh", a)
896907
a, = promote_dtypes_inexact(a)
897-
w, _ = eigh(a, UPLO)
908+
w, _ = eigh(a, UPLO, symmetrize_input=symmetrize_input)
898909
return w
899910

900911

tests/linalg_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def args_maker():
9696
a = rng(factor_shape, dtype)
9797
return [np.matmul(a, jnp.conj(T(a)))]
9898

99-
jnp_fun = partial(jnp.linalg.cholesky, upper=upper)
99+
jnp_fun = partial(jnp.linalg.cholesky, upper=upper, symmetrize_input=True)
100100

101101
def np_fun(x, upper=upper):
102102
# Upper argument added in NumPy 2.0.0

0 commit comments

Comments
 (0)