@@ -72,8 +72,8 @@ def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
7272
7373
7474@export
75- @partial (jit , static_argnames = ['upper' ])
76- def cholesky (a : ArrayLike , * , upper : bool = False ) -> Array :
75+ @partial (jit , static_argnames = ['upper' , 'symmetrize_input' ])
76+ def cholesky (a : ArrayLike , * , upper : bool = False , symmetrize_input : bool = True ) -> Array :
7777 """Compute the Cholesky decomposition of a matrix.
7878
7979 JAX implementation of :func:`numpy.linalg.cholesky`.
@@ -98,6 +98,10 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
9898 Must have shape ``(..., N, N)``.
9999 upper: if True, compute the upper Cholesky decomposition `U`. if False
100100 (default), compute the lower Cholesky decomposition `L`.
101+ symmetrize_input: if True (default) then input is symmetrized, which leads
102+ to better behavior under automatic differentiation. Note that when this
103+ is set to True, both the upper and lower triangles of the input will
104+ be used in computing the decomposition.
101105
102106 Returns:
103107 array of shape ``(..., N, N)`` representing the Cholesky decomposition
@@ -135,7 +139,7 @@ def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
135139 """
136140 a = ensure_arraylike ("jnp.linalg.cholesky" , a )
137141 a , = promote_dtypes_inexact (a )
138- L = lax_linalg .cholesky (a )
142+ L = lax_linalg .cholesky (a , symmetrize_input = symmetrize_input )
139143 return L .mT .conj () if upper else L
140144
141145
@@ -821,7 +825,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
821825 UPLO: specifies whether the calculation is done with the lower triangular
822826 part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
823827 symmetrize_input: if True (default) then input is symmetrized, which leads
824- to better behavior under automatic differentiation.
828+ to better behavior under automatic differentiation. Note that when this
829+ is set to True, both the upper and lower triangles of the input will
830+ be used in computing the decomposition.
825831
826832 Returns:
827833 A namedtuple ``(eigenvalues, eigenvectors)`` where
@@ -863,8 +869,9 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
863869
864870
865871@export
866- @partial (jit , static_argnames = ('UPLO' ,))
867- def eigvalsh (a : ArrayLike , UPLO : str | None = 'L' ) -> Array :
872+ @partial (jit , static_argnames = ('UPLO' , 'symmetrize_input' ))
873+ def eigvalsh (a : ArrayLike , UPLO : str | None = 'L' , * ,
874+ symmetrize_input : bool = True ) -> Array :
868875 """
869876 Compute the eigenvalues of a Hermitian matrix.
870877
@@ -875,6 +882,10 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
875882 or symmetric (if real) matrix.
876883 UPLO: specifies whether the calculation is done with the lower triangular
877884 part of ``a`` (``'L'``, default) or the upper triangular part (``'U'``).
885+ symmetrize_input: if True (default) then input is symmetrized, which leads
886+ to better behavior under automatic differentiation. Note that when this
887+ is set to True, both the upper and lower triangles of the input will
888+ be used in computing the decomposition.
878889
879890 Returns:
880891 An array of shape ``(..., M)`` containing the eigenvalues, sorted in
@@ -894,7 +905,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
894905 """
895906 a = ensure_arraylike ("jnp.linalg.eigvalsh" , a )
896907 a , = promote_dtypes_inexact (a )
897- w , _ = eigh (a , UPLO )
908+ w , _ = eigh (a , UPLO , symmetrize_input = symmetrize_input )
898909 return w
899910
900911
0 commit comments