Skip to content

Commit 5a1c4c5

Browse files
Merge pull request jax-ml#25338 from carlosgmartin:fix_numpy_linalg_matrix_norm_ord_type_annotation
PiperOrigin-RevId: 704245037
2 parents d474fed + efa35ea commit 5a1c4c5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/numpy/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,7 @@ def outer(x1: ArrayLike, x2: ArrayLike, /) -> Array:
15171517

15181518

15191519
@export
1520-
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str = 'fro') -> Array:
1520+
def matrix_norm(x: ArrayLike, /, *, keepdims: bool = False, ord: str | int = 'fro') -> Array:
15211521
"""Compute the norm of a matrix or stack of matrices.
15221522
15231523
JAX implementation of :func:`numpy.linalg.matrix_norm`

0 commit comments

Comments
 (0)