3535from jax ._src .numpy import lax_numpy as jnp
3636from jax ._src .numpy import reductions , ufuncs
3737from jax ._src .numpy .util import promote_dtypes_inexact , check_arraylike
38- from jax ._src .util import canonicalize_axis
38+ from jax ._src .util import canonicalize_axis , set_module
3939from jax ._src .typing import ArrayLike , Array , DTypeLike , DeprecatedArg
4040
4141
42+ export = set_module ('jax.numpy.linalg' )
43+
44+
4245class EighResult (NamedTuple ):
4346 eigenvalues : jax .Array
4447 eigenvectors : jax .Array
@@ -67,6 +70,7 @@ def _H(x: ArrayLike) -> Array:
6770def _symmetrize (x : Array ) -> Array : return (x + _H (x )) / 2
6871
6972
73+ @export
7074@partial (jit , static_argnames = ['upper' ])
7175def cholesky (a : ArrayLike , * , upper : bool = False ) -> Array :
7276 """Compute the Cholesky decomposition of a matrix.
@@ -191,6 +195,7 @@ def svd(
191195 ...
192196
193197
198+ @export
194199@partial (
195200 jit ,
196201 static_argnames = (
@@ -311,6 +316,7 @@ def svd(
311316 )
312317
313318
319+ @export
314320@partial (jit , static_argnames = ('n' ,))
315321def matrix_power (a : ArrayLike , n : int ) -> Array :
316322 """Raise a square matrix to an integer power.
@@ -392,6 +398,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
392398 return result
393399
394400
401+ @export
395402@jit
396403def matrix_rank (
397404 M : ArrayLike , rtol : ArrayLike | None = None , * ,
@@ -496,6 +503,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
496503 return sign_diag * sign_taus , log_abs_det
497504
498505
506+ @export
499507@partial (jit , static_argnames = ('method' ,))
500508def slogdet (a : ArrayLike , * , method : str | None = None ) -> SlogdetResult :
501509 """
@@ -675,6 +683,7 @@ def _det_jvp(primals, tangents):
675683 return y , jnp .trace (z , axis1 = - 1 , axis2 = - 2 )
676684
677685
686+ @export
678687@jit
679688def det (a : ArrayLike ) -> Array :
680689 """
@@ -711,6 +720,7 @@ def det(a: ArrayLike) -> Array:
711720 raise ValueError (msg .format (a_shape ))
712721
713722
723+ @export
714724def eig (a : ArrayLike ) -> tuple [Array , Array ]:
715725 """
716726 Compute the eigenvalues and eigenvectors of a square array.
@@ -756,6 +766,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
756766 return w , v
757767
758768
769+ @export
759770@jit
760771def eigvals (a : ArrayLike ) -> Array :
761772 """
@@ -793,6 +804,7 @@ def eigvals(a: ArrayLike) -> Array:
793804 compute_right_eigenvectors = False )[0 ]
794805
795806
807+ @export
796808@partial (jit , static_argnames = ('UPLO' , 'symmetrize_input' ))
797809def eigh (a : ArrayLike , UPLO : str | None = None ,
798810 symmetrize_input : bool = True ) -> EighResult :
@@ -848,6 +860,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
848860 return EighResult (w , v )
849861
850862
863+ @export
851864@partial (jit , static_argnames = ('UPLO' ,))
852865def eigvalsh (a : ArrayLike , UPLO : str | None = 'L' ) -> Array :
853866 """
@@ -884,6 +897,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
884897
885898
886899# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires.
900+ @export
887901def pinv (a : ArrayLike , rtol : ArrayLike | None = None ,
888902 hermitian : bool = False , * ,
889903 rcond : ArrayLike | DeprecatedArg | None = DeprecatedArg ()) -> Array :
@@ -997,6 +1011,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents):
9971011 return p , p_dot
9981012
9991013
1014+ @export
10001015@jit
10011016def inv (a : ArrayLike ) -> Array :
10021017 """Return the inverse of a square matrix
@@ -1057,6 +1072,7 @@ def inv(a: ArrayLike) -> Array:
10571072 arr , lax .broadcast (jnp .eye (arr .shape [- 1 ], dtype = arr .dtype ), arr .shape [:- 2 ]))
10581073
10591074
1075+ @export
10601076@partial (jit , static_argnames = ('ord' , 'axis' , 'keepdims' ))
10611077def norm (x : ArrayLike , ord : int | str | None = None ,
10621078 axis : None | tuple [int , ...] | int = None ,
@@ -1222,6 +1238,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
12221238@overload
12231239def qr (a : ArrayLike , mode : str = "reduced" ) -> Array | QRResult : ...
12241240
1241+ @export
12251242@partial (jit , static_argnames = ('mode' ,))
12261243def qr (a : ArrayLike , mode : str = "reduced" ) -> Array | QRResult :
12271244 """Compute the QR decomposition of an array
@@ -1305,6 +1322,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
13051322 return QRResult (q , r )
13061323
13071324
1325+ @export
13081326@jit
13091327def solve (a : ArrayLike , b : ArrayLike ) -> Array :
13101328 """Solve a linear system of equations
@@ -1408,6 +1426,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
14081426_jit_lstsq = jit (partial (_lstsq , numpy_resid = False ))
14091427
14101428
1429+ @export
14111430def lstsq (a : ArrayLike , b : ArrayLike , rcond : float | None = None , * ,
14121431 numpy_resid : bool = False ) -> tuple [Array , Array , Array , Array ]:
14131432 """
@@ -1448,6 +1467,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
14481467 return _jit_lstsq (a , b , rcond )
14491468
14501469
1470+ @export
14511471def cross (x1 : ArrayLike , x2 : ArrayLike , / , * , axis = - 1 ):
14521472 r"""Compute the cross-product of two 3D vectors
14531473
@@ -1493,6 +1513,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
14931513 return jnp .cross (x1 , x2 , axis = axis )
14941514
14951515
1516+ @export
14961517def outer (x1 : ArrayLike , x2 : ArrayLike , / ) -> Array :
14971518 """Compute the outer product of two 1-dimensional arrays.
14981519
@@ -1523,6 +1544,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
15231544 return x1 [:, None ] * x2 [None , :]
15241545
15251546
1547+ @export
15261548def matrix_norm (x : ArrayLike , / , * , keepdims : bool = False , ord : str = 'fro' ) -> Array :
15271549 """Compute the norm of a matrix or stack of matrices.
15281550
@@ -1553,6 +1575,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') ->
15531575 return norm (x , ord = ord , keepdims = keepdims , axis = (- 2 , - 1 ))
15541576
15551577
1578+ @export
15561579def matrix_transpose (x : ArrayLike , / ) -> Array :
15571580 """Transpose a matrix or stack of matrices.
15581581
@@ -1608,6 +1631,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
16081631 return jax .lax .transpose (x_arr , (* range (ndim - 2 ), ndim - 1 , ndim - 2 ))
16091632
16101633
1634+ @export
16111635def vector_norm (x : ArrayLike , / , * , axis : int | None = None , keepdims : bool = False ,
16121636 ord : int | str = 2 ) -> Array :
16131637 """Compute the vector norm of a vector or batch of vectors.
@@ -1652,6 +1676,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
16521676 return norm (x , axis = axis , keepdims = keepdims , ord = ord )
16531677
16541678
1679+ @export
16551680def vecdot (x1 : ArrayLike , x2 : ArrayLike , / , * , axis : int = - 1 ,
16561681 precision : PrecisionLike = None ,
16571682 preferred_element_type : DTypeLike | None = None ) -> Array :
@@ -1702,6 +1727,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
17021727 preferred_element_type = preferred_element_type )
17031728
17041729
1730+ @export
17051731def matmul (x1 : ArrayLike , x2 : ArrayLike , / , * ,
17061732 precision : PrecisionLike = None ,
17071733 preferred_element_type : DTypeLike | None = None ) -> Array :
@@ -1762,6 +1788,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
17621788 preferred_element_type = preferred_element_type )
17631789
17641790
1791+ @export
17651792def tensordot (x1 : ArrayLike , x2 : ArrayLike , / , * ,
17661793 axes : int | tuple [Sequence [int ], Sequence [int ]] = 2 ,
17671794 precision : PrecisionLike = None ,
@@ -1843,6 +1870,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
18431870 preferred_element_type = preferred_element_type )
18441871
18451872
1873+ @export
18461874def svdvals (x : ArrayLike , / ) -> Array :
18471875 """Compute the singular values of a matrix.
18481876
@@ -1867,6 +1895,7 @@ def svdvals(x: ArrayLike, /) -> Array:
18671895 return svd (x , compute_uv = False , hermitian = False )
18681896
18691897
1898+ @export
18701899def diagonal (x : ArrayLike , / , * , offset : int = 0 ) -> Array :
18711900 """Extract the diagonal of an matrix or stack of matrices.
18721901
@@ -1907,6 +1936,7 @@ def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
19071936 return jnp .diagonal (x , offset = offset , axis1 = - 2 , axis2 = - 1 )
19081937
19091938
1939+ @export
19101940def tensorinv (a : ArrayLike , ind : int = 2 ) -> Array :
19111941 """Compute the tensor inverse of an array.
19121942
@@ -1949,6 +1979,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array:
19491979 return inv (arr .reshape (flatshape )).reshape (* batch_shape , * contracting_shape )
19501980
19511981
1982+ @export
19521983def tensorsolve (a : ArrayLike , b : ArrayLike , axes : tuple [int , ...] | None = None ) -> Array :
19531984 """Solve the tensor equation a x = b for x.
19541985
@@ -1998,6 +2029,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None)
19982029 return solve (a_arr , b_arr .ravel ()).reshape (out_shape )
19992030
20002031
2032+ @export
20012033def multi_dot (arrays : Sequence [ArrayLike ], * , precision : PrecisionLike = None ) -> Array :
20022034 """Efficiently compute matrix products between a sequence of arrays.
20032035
@@ -2090,6 +2122,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -
20902122 optimize = 'optimal' , precision = precision )
20912123
20922124
2125+ @export
20932126@partial (jit , static_argnames = ['p' ])
20942127def cond (x : ArrayLike , p = None ):
20952128 """Compute the condition number of a matrix.
@@ -2149,6 +2182,7 @@ def cond(x: ArrayLike, p=None):
21492182 return jnp .where (ufuncs .isnan (r ) & ~ ufuncs .isnan (x ).any (axis = (- 2 , - 1 )), jnp .inf , r )
21502183
21512184
2185+ @export
21522186def trace (x : ArrayLike , / , * ,
21532187 offset : int = 0 , dtype : DTypeLike | None = None ) -> Array :
21542188 """Compute the trace of a matrix.
0 commit comments