Commit aaaee63

jnp.linalg.vector_norm: properly support multiple axes

1 parent 3f5f3e1 commit aaaee63
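
Context for the change: under the old code, `vector_norm` forwarded a tuple axis to `norm`, where a two-axis reduction is interpreted as a matrix norm; after this commit the reduction stays a vector norm over all listed axes. A minimal sketch of the intended behavior, with illustrative values:

    import jax.numpy as jnp

    x = jnp.arange(24.0).reshape(2, 3, 4)

    # vector_norm now reduces over all requested axes as one flat vector norm,
    # rather than dispatching to a matrix norm for a two-axis tuple.
    out = jnp.linalg.vector_norm(x, axis=(1, 2))             # shape (2,)
    ref = jnp.linalg.vector_norm(x.reshape(2, -1), axis=1)   # same values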

File tree: 2 files changed, +57 -51 lines


jax/_src/numpy/linalg.py

Lines changed: 31 additions & 37 deletions
@@ -1159,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
 
   num_axes = len(axis)
   if num_axes == 1:
-    if ord is None or ord == 2:
-      return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
-                                        keepdims=keepdims))
-    elif ord == jnp.inf:
-      return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == -jnp.inf:
-      return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif ord == 0:
-      return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
-                            axis=axis, keepdims=keepdims)
-    elif ord == 1:
-      # Numpy has a special case for ord == 1 as an optimization. We don't
-      # really need the optimization (XLA could do it for us), but the Numpy
-      # code has slightly different type promotion semantics, so we need a
-      # special case too.
-      return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
-    elif isinstance(ord, str):
-      msg = f"Invalid order '{ord}' for vector norm."
-      if ord == "inf":
-        msg += "Use 'jax.numpy.inf' instead."
-      if ord == "-inf":
-        msg += "Use '-jax.numpy.inf' instead."
-      raise ValueError(msg)
-    else:
-      abs_x = ufuncs.abs(x)
-      ord_arr = lax_internal._const(abs_x, ord)
-      ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
-      out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
-      return ufuncs.power(out, ord_inv)
+    return vector_norm(x, ord=2 if ord is None else ord, axis=axis, keepdims=keepdims)
 
   elif num_axes == 2:
     row_axis, col_axis = axis  # pytype: disable=bad-unpacking
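
With this hunk, the single-axis branch of `norm` delegates to `vector_norm` instead of duplicating the per-`ord` logic. A small sketch of the equivalence this relies on (values illustrative):

    import jax.numpy as jnp

    x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

    # norm with a single axis should now match vector_norm exactly,
    # including the default ord=None -> 2 mapping.
    a = jnp.linalg.norm(x, ord=3, axis=0)
    b = jnp.linalg.vector_norm(x, ord=3, axis=0)
    assert jnp.allclose(a, b)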
@@ -1632,7 +1604,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
 
 
 @export
-def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = False,
+def vector_norm(x: ArrayLike, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False,
                 ord: int | str = 2) -> Array:
   """Compute the vector norm of a vector or batch of vectors.
@@ -1668,13 +1640,35 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
   Array([3.7416575, 9.486833 ], dtype=float32)
   """
   check_arraylike('jnp.linalg.vector_norm', x)
-  if axis is None:
-    result = norm(jnp.ravel(x), ord=ord)
-    if keepdims:
-      result = lax.expand_dims(result, range(jnp.ndim(x)))
-    return result
-  return norm(x, axis=axis, keepdims=keepdims, ord=ord)
-
+  if ord is None or ord == 2:
+    return ufuncs.sqrt(reductions.sum(ufuncs.real(x * ufuncs.conj(x)), axis=axis,
+                                      keepdims=keepdims))
+  elif ord == jnp.inf:
+    return reductions.amax(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif ord == -jnp.inf:
+    return reductions.amin(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif ord == 0:
+    return reductions.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
+                          axis=axis, keepdims=keepdims)
+  elif ord == 1:
+    # Numpy has a special case for ord == 1 as an optimization. We don't
+    # really need the optimization (XLA could do it for us), but the Numpy
+    # code has slightly different type promotion semantics, so we need a
+    # special case too.
+    return reductions.sum(ufuncs.abs(x), axis=axis, keepdims=keepdims)
+  elif isinstance(ord, str):
+    msg = f"Invalid order '{ord}' for vector norm."
+    if ord == "inf":
+      msg += "Use 'jax.numpy.inf' instead."
+    if ord == "-inf":
+      msg += "Use '-jax.numpy.inf' instead."
+    raise ValueError(msg)
+  else:
+    abs_x = ufuncs.abs(x)
+    ord_arr = lax_internal._const(abs_x, ord)
+    ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
+    out = reductions.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
+    return ufuncs.power(out, ord_inv)
 
 @export
 def vecdot(x1: ArrayLike, x2: ArrayLike, /, *, axis: int = -1,
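
The final `else` branch implements the general p-norm, (sum_i |x_i|^p)^(1/p); the earlier branches are the usual special cases (sqrt of summed squared magnitudes for 2, max/min of |x| for +/-inf, count of nonzeros for 0, sum of |x| for 1). A quick sketch of the closed form, with illustrative values:

    import jax.numpy as jnp

    x = jnp.array([1.0, -2.0, 3.0])
    p = 3

    manual = jnp.sum(jnp.abs(x) ** p) ** (1.0 / p)
    assert jnp.allclose(jnp.linalg.vector_norm(x, ord=p), manual)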

tests/linalg_test.py

Lines changed: 26 additions & 14 deletions
@@ -16,6 +16,8 @@
 
 from functools import partial
 import itertools
+from typing import Iterator
+from unittest import skipIf
 
 import numpy as np
 import scipy
@@ -54,6 +56,20 @@ def _is_required_cuda_version_satisfied(cuda_version):
   return int(version.split()[-1]) >= cuda_version
 
 
+def _axis_for_ndim(ndim: int) -> Iterator[None | int | tuple[int, ...]]:
+  """
+  Generate a range of valid axis arguments for a reduction over
+  an array with a given number of dimensions.
+  """
+  yield from (None, ())
+  if ndim > 0:
+    yield from (0, (-1,))
+  if ndim > 1:
+    yield from (1, (0, 1), (-1, 0))
+  if ndim > 2:
+    yield (-1, 0, 1)
+
+
 def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray:
   """scipy.linalg.toeplitz with v1.17+ batching semantics."""
   if scipy_version >= (1, 17, 0):
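
For reference, the new helper builds up axis cases by rank, so higher-rank arrays also exercise tuple axes and negative indices. A sketch of its output for a rank-2 array:

    >>> list(_axis_for_ndim(2))
    [None, (), 0, (-1,), 1, (0, 1), (-1, 0)]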
@@ -707,29 +723,25 @@ def testMatrixNorm(self, shape, dtype, keepdims, ord):
     self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
     self._CompileAndCheck(jnp_fn, args_maker)
 
+  @skipIf(jtu.numpy_version() < (2, 0, 0), "np.linalg.vector_norm requires NumPy 2.0")
   @jtu.sample_product(
-      shape=[(3,), (3, 4), (2, 3, 4, 5)],
+      [
+        dict(shape=shape, axis=axis)
+        for shape in [(3,), (3, 4), (2, 3, 4, 5)]
+        for axis in _axis_for_ndim(len(shape))
+      ],
       dtype=float_types + complex_types,
       keepdims=[True, False],
-      axis=[0, None],
       ord=[1, -1, 2, -2, np.inf, -np.inf],
   )
   def testVectorNorm(self, shape, dtype, keepdims, axis, ord):
     rng = jtu.rand_default(self.rng())
     args_maker = lambda: [rng(shape, dtype)]
-    if jtu.numpy_version() < (2, 0, 0):
-      def np_fn(x, *, ord, keepdims, axis):
-        x = np.asarray(x)
-        if axis is None:
-          result = np_fn(x.ravel(), ord=ord, keepdims=False, axis=0)
-          return np.reshape(result, (1,) * x.ndim) if keepdims else result
-        return np.linalg.norm(x, ord=ord, keepdims=keepdims, axis=axis)
-    else:
-      np_fn = np.linalg.vector_norm
-    np_fn = partial(np_fn, ord=ord, keepdims=keepdims, axis=axis)
+    np_fn = partial(np.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
     jnp_fn = partial(jnp.linalg.vector_norm, ord=ord, keepdims=keepdims, axis=axis)
-    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-3)
-    self._CompileAndCheck(jnp_fn, args_maker)
+    tol = 1E-3 if jtu.test_device_matches(['tpu']) else None
+    self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=tol)
+    self._CompileAndCheck(jnp_fn, args_maker, tol=tol)
 
   # jnp.linalg.vecdot is an alias of jnp.vecdot; do a minimal test here.
   @jtu.sample_product(
