Skip to content

Commit d219439

Browse files
Merge pull request jax-ml#25011 from jakevdp:jnp-linalg-module
PiperOrigin-RevId: 698517512
2 parents 9d2f62f + 621e39d commit d219439

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

jax/_src/numpy/linalg.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@
3535
from jax._src.numpy import lax_numpy as jnp
3636
from jax._src.numpy import reductions, ufuncs
3737
from jax._src.numpy.util import promote_dtypes_inexact, check_arraylike
38-
from jax._src.util import canonicalize_axis
38+
from jax._src.util import canonicalize_axis, set_module
3939
from jax._src.typing import ArrayLike, Array, DTypeLike, DeprecatedArg
4040

4141

42+
export = set_module('jax.numpy.linalg')
43+
44+
4245
class EighResult(NamedTuple):
4346
eigenvalues: jax.Array
4447
eigenvectors: jax.Array
@@ -67,6 +70,7 @@ def _H(x: ArrayLike) -> Array:
6770
def _symmetrize(x: Array) -> Array: return (x + _H(x)) / 2
6871

6972

73+
@export
7074
@partial(jit, static_argnames=['upper'])
7175
def cholesky(a: ArrayLike, *, upper: bool = False) -> Array:
7276
"""Compute the Cholesky decomposition of a matrix.
@@ -191,6 +195,7 @@ def svd(
191195
...
192196

193197

198+
@export
194199
@partial(
195200
jit,
196201
static_argnames=(
@@ -311,6 +316,7 @@ def svd(
311316
)
312317

313318

319+
@export
314320
@partial(jit, static_argnames=('n',))
315321
def matrix_power(a: ArrayLike, n: int) -> Array:
316322
"""Raise a square matrix to an integer power.
@@ -392,6 +398,7 @@ def matrix_power(a: ArrayLike, n: int) -> Array:
392398
return result
393399

394400

401+
@export
395402
@jit
396403
def matrix_rank(
397404
M: ArrayLike, rtol: ArrayLike | None = None, *,
@@ -496,6 +503,7 @@ def _slogdet_qr(a: Array) -> tuple[Array, Array]:
496503
return sign_diag * sign_taus, log_abs_det
497504

498505

506+
@export
499507
@partial(jit, static_argnames=('method',))
500508
def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult:
501509
"""
@@ -675,6 +683,7 @@ def _det_jvp(primals, tangents):
675683
return y, jnp.trace(z, axis1=-1, axis2=-2)
676684

677685

686+
@export
678687
@jit
679688
def det(a: ArrayLike) -> Array:
680689
"""
@@ -711,6 +720,7 @@ def det(a: ArrayLike) -> Array:
711720
raise ValueError(msg.format(a_shape))
712721

713722

723+
@export
714724
def eig(a: ArrayLike) -> tuple[Array, Array]:
715725
"""
716726
Compute the eigenvalues and eigenvectors of a square array.
@@ -756,6 +766,7 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
756766
return w, v
757767

758768

769+
@export
759770
@jit
760771
def eigvals(a: ArrayLike) -> Array:
761772
"""
@@ -793,6 +804,7 @@ def eigvals(a: ArrayLike) -> Array:
793804
compute_right_eigenvectors=False)[0]
794805

795806

807+
@export
796808
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
797809
def eigh(a: ArrayLike, UPLO: str | None = None,
798810
symmetrize_input: bool = True) -> EighResult:
@@ -848,6 +860,7 @@ def eigh(a: ArrayLike, UPLO: str | None = None,
848860
return EighResult(w, v)
849861

850862

863+
@export
851864
@partial(jit, static_argnames=('UPLO',))
852865
def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
853866
"""
@@ -884,6 +897,7 @@ def eigvalsh(a: ArrayLike, UPLO: str | None = 'L') -> Array:
884897

885898

886899
# TODO(micky774): deprecated 2024-5-14, remove wrapper after deprecation expires.
900+
@export
887901
def pinv(a: ArrayLike, rtol: ArrayLike | None = None,
888902
hermitian: bool = False, *,
889903
rcond: ArrayLike | DeprecatedArg | None = DeprecatedArg()) -> Array:
@@ -997,6 +1011,7 @@ def _pinv_jvp(rtol, hermitian, primals, tangents):
9971011
return p, p_dot
9981012

9991013

1014+
@export
10001015
@jit
10011016
def inv(a: ArrayLike) -> Array:
10021017
"""Return the inverse of a square matrix
@@ -1057,6 +1072,7 @@ def inv(a: ArrayLike) -> Array:
10571072
arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))
10581073

10591074

1075+
@export
10601076
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
10611077
def norm(x: ArrayLike, ord: int | str | None = None,
10621078
axis: None | tuple[int, ...] | int = None,
@@ -1222,6 +1238,7 @@ def qr(a: ArrayLike, mode: Literal["r"]) -> Array: ...
12221238
@overload
12231239
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult: ...
12241240

1241+
@export
12251242
@partial(jit, static_argnames=('mode',))
12261243
def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
12271244
"""Compute the QR decomposition of an array
@@ -1305,6 +1322,7 @@ def qr(a: ArrayLike, mode: str = "reduced") -> Array | QRResult:
13051322
return QRResult(q, r)
13061323

13071324

1325+
@export
13081326
@jit
13091327
def solve(a: ArrayLike, b: ArrayLike) -> Array:
13101328
"""Solve a linear system of equations
@@ -1408,6 +1426,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
14081426
_jit_lstsq = jit(partial(_lstsq, numpy_resid=False))
14091427

14101428

1429+
@export
14111430
def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
14121431
numpy_resid: bool = False) -> tuple[Array, Array, Array, Array]:
14131432
"""
@@ -1448,6 +1467,7 @@ def lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None = None, *,
14481467
return _jit_lstsq(a, b, rcond)
14491468

14501469

1470+
@export
14511471
def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
14521472
r"""Compute the cross-product of two 3D vectors
14531473
@@ -1493,6 +1513,7 @@ def cross(x1: ArrayLike, x2: ArrayLike, /, *, axis=-1):
14931513
return jnp.cross(x1, x2, axis=axis)
14941514

14951515

1516+
@export
14961517
def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
14971518
"""Compute the outer product of two 1-dimensional arrays.
14981519
@@ -1523,6 +1544,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
15231544
return x1[:, None] * x2[None, :]
15241545

15251546

1547+
@export
15261548
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
15271549
"""Compute the norm of a matrix or stack of matrices.
15281550
@@ -1553,6 +1575,7 @@ def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') ->
15531575
return norm(x, ord=ord, keepdims=keepdims, axis=(-2, -1))
15541576

15551577

1578+
@export
15561579
def matrix_transpose(x: ArrayLike, /) -> Array:
15571580
"""Transpose a matrix or stack of matrices.
15581581
@@ -1608,6 +1631,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
16081631
return jax.lax.transpose(x_arr, (*range(ndim - 2), ndim - 1, ndim - 2))
16091632

16101633

1634+
@export
16111635
def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
16121636
ord: int | str = 2) -> Array:
16131637
"""Compute the vector norm of a vector or batch of vectors.
@@ -1652,6 +1676,7 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
16521676
return norm(x, axis=axis, keepdims=keepdims, ord=ord)
16531677

16541678

1679+
@export
16551680
def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
16561681
precision: PrecisionLike = None,
16571682
preferred_element_type: DTypeLike | None = None) -> Array:
@@ -1702,6 +1727,7 @@ def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
17021727
preferred_element_type=preferred_element_type)
17031728

17041729

1730+
@export
17051731
def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
17061732
precision: PrecisionLike = None,
17071733
preferred_element_type: DTypeLike | None = None) -> Array:
@@ -1762,6 +1788,7 @@ def matmul(x1: ArrayLike, x2: ArrayLike, /, *,
17621788
preferred_element_type=preferred_element_type)
17631789

17641790

1791+
@export
17651792
def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
17661793
axes: int | tuple[Sequence[int], Sequence[int]] = 2,
17671794
precision: PrecisionLike = None,
@@ -1843,6 +1870,7 @@ def tensordot(x1: ArrayLike, x2: ArrayLike, /, *,
18431870
preferred_element_type=preferred_element_type)
18441871

18451872

1873+
@export
18461874
def svdvals(x: ArrayLike, /) -> Array:
18471875
"""Compute the singular values of a matrix.
18481876
@@ -1867,6 +1895,7 @@ def svdvals(x: ArrayLike, /) -> Array:
18671895
return svd(x, compute_uv=False, hermitian=False)
18681896

18691897

1898+
@export
18701899
def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
18711900
"""Extract the diagonal of an matrix or stack of matrices.
18721901
@@ -1907,6 +1936,7 @@ def diagonal(x: ArrayLike, /, *, offset: int = 0) -> Array:
19071936
return jnp.diagonal(x, offset=offset, axis1=-2, axis2=-1)
19081937

19091938

1939+
@export
19101940
def tensorinv(a: ArrayLike, ind: int = 2) -> Array:
19111941
"""Compute the tensor inverse of an array.
19121942
@@ -1949,6 +1979,7 @@ def tensorinv(a: ArrayLike, ind: int = 2) -> Array:
19491979
return inv(arr.reshape(flatshape)).reshape(*batch_shape, *contracting_shape)
19501980

19511981

1982+
@export
19521983
def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None) -> Array:
19531984
"""Solve the tensor equation a x = b for x.
19541985
@@ -1998,6 +2029,7 @@ def tensorsolve(a: ArrayLike, b: ArrayLike, axes: tuple[int, ...] | None = None)
19982029
return solve(a_arr, b_arr.ravel()).reshape(out_shape)
19992030

20002031

2032+
@export
20012033
def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -> Array:
20022034
"""Efficiently compute matrix products between a sequence of arrays.
20032035
@@ -2090,6 +2122,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -
20902122
optimize='optimal', precision=precision)
20912123

20922124

2125+
@export
20932126
@partial(jit, static_argnames=['p'])
20942127
def cond(x: ArrayLike, p=None):
20952128
"""Compute the condition number of a matrix.
@@ -2149,6 +2182,7 @@ def cond(x: ArrayLike, p=None):
21492182
return jnp.where(ufuncs.isnan(r) & ~ufuncs.isnan(x).any(axis=(-2, -1)), jnp.inf, r)
21502183

21512184

2185+
@export
21522186
def trace(x: ArrayLike, /, *,
21532187
offset: int = 0, dtype: DTypeLike | None = None) -> Array:
21542188
"""Compute the trace of a matrix.

tests/package_structure_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class PackageStructureTest(jtu.JaxTestCase):
4040
"number", "object_", "printoptions", "save", "savez", "set_printoptions",
4141
"shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"]
4242
),
43+
_mod("jax.numpy.linalg"),
4344
_mod("jax.nn.initializers"),
4445
_mod(
4546
"jax.tree_util",

0 commit comments

Comments
 (0)