Skip to content

Commit 5e9b1b0

Browse files
committed
Add matrix_norm and vector_norm
1 parent 422910b commit 5e9b1b0

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

numpy_array_api_compat/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
"""
2020

2121
from numpy import *
22-
from . import linalg
2322

2423
# These imports may overwrite names from the import * above.
2524
from ._aliases import *
25+
26+
# Don't know why, but we have to do this to import linalg. If we instead do
27+
#
28+
# from . import linalg
29+
#
30+
# It doesn't overwrite np.linalg from above.
31+
import numpy_array_api_compat.linalg

numpy_array_api_compat/linalg.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,51 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
if TYPE_CHECKING:
5+
from ._typing import Literal, Optional, Tuple, Union
6+
from numpy import ndarray
7+
8+
import numpy as np
9+
from numpy.core.numeric import normalize_axis_tuple
10+
11+
def matrix_norm(x: ndarray, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray:
12+
return np.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord)
13+
14+
def vector_norm(x: ndarray, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray:
15+
# np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
16+
# when axis=None and the input is 2-D, so to force a vector norm, we make
17+
# it so the input is 1-D (for axis=None), or reshape so that norm is done
18+
# on a single dimension.
19+
if axis is None:
20+
# Note: np.linalg.norm() doesn't handle 0-D arrays
21+
x = x.ravel()
22+
_axis = 0
23+
elif isinstance(axis, tuple):
24+
# Note: The axis argument supports any number of axes, whereas
25+
# np.linalg.norm() only supports a single axis for vector norm.
26+
normalized_axis = normalize_axis_tuple(axis, x.ndim)
27+
rest = tuple(i for i in range(x.ndim) if i not in normalized_axis)
28+
newshape = axis + rest
29+
x = np.transpose(x, newshape).reshape(
30+
(np.prod([x.shape[i] for i in axis], dtype=int), *[x.shape[i] for i in rest]))
31+
_axis = 0
32+
else:
33+
_axis = axis
34+
35+
res = np.linalg.norm(x, axis=_axis, ord=ord)
36+
37+
if keepdims:
38+
# We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
39+
# above to avoid matrix norm logic.
40+
shape = list(x.shape)
41+
_axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
42+
for i in _axis:
43+
shape[i] = 1
44+
res = np.reshape(res, tuple(shape))
45+
46+
return res
47+
148
from numpy.linalg import *
49+
from numpy.linalg import __all__ as linalg_all
50+
__all__ = linalg_all.copy()
51+
__all__ += ['matrix_norm', 'vector_norm']

0 commit comments

Comments
 (0)