@@ -1159,35 +1159,7 @@ def norm(x: ArrayLike, ord: int | str | None = None,
11591159
11601160 num_axes = len (axis )
11611161 if num_axes == 1 :
1162- if ord is None or ord == 2 :
1163- return ufuncs .sqrt (reductions .sum (ufuncs .real (x * ufuncs .conj (x )), axis = axis ,
1164- keepdims = keepdims ))
1165- elif ord == jnp .inf :
1166- return reductions .amax (ufuncs .abs (x ), axis = axis , keepdims = keepdims )
1167- elif ord == - jnp .inf :
1168- return reductions .amin (ufuncs .abs (x ), axis = axis , keepdims = keepdims )
1169- elif ord == 0 :
1170- return reductions .sum (x != 0 , dtype = jnp .finfo (lax .dtype (x )).dtype ,
1171- axis = axis , keepdims = keepdims )
1172- elif ord == 1 :
1173- # Numpy has a special case for ord == 1 as an optimization. We don't
1174- # really need the optimization (XLA could do it for us), but the Numpy
1175- # code has slightly different type promotion semantics, so we need a
1176- # special case too.
1177- return reductions .sum (ufuncs .abs (x ), axis = axis , keepdims = keepdims )
1178- elif isinstance (ord , str ):
1179- msg = f"Invalid order '{ ord } ' for vector norm."
1180- if ord == "inf" :
1181- msg += "Use 'jax.numpy.inf' instead."
1182- if ord == "-inf" :
1183- msg += "Use '-jax.numpy.inf' instead."
1184- raise ValueError (msg )
1185- else :
1186- abs_x = ufuncs .abs (x )
1187- ord_arr = lax_internal ._const (abs_x , ord )
1188- ord_inv = lax_internal ._const (abs_x , 1. / ord_arr )
1189- out = reductions .sum (abs_x ** ord_arr , axis = axis , keepdims = keepdims )
1190- return ufuncs .power (out , ord_inv )
1162+ return vector_norm (x , ord = 2 if ord is None else ord , axis = axis , keepdims = keepdims )
11911163
11921164 elif num_axes == 2 :
11931165 row_axis , col_axis = axis # pytype: disable=bad-unpacking
@@ -1632,7 +1604,7 @@ def matrix_transpose(x: ArrayLike, /) -> Array:
16321604
16331605
16341606@export
1635- def vector_norm (x : ArrayLike , / , * , axis : int | None = None , keepdims : bool = False ,
1607+ def vector_norm (x : ArrayLike , / , * , axis : int | tuple [ int , ...] | None = None , keepdims : bool = False ,
16361608 ord : int | str = 2 ) -> Array :
16371609 """Compute the vector norm of a vector or batch of vectors.
16381610
@@ -1668,13 +1640,35 @@ def vector_norm(x: ArrayLike, /, *, axis: int | None = None, keepdims: bool = Fa
16681640 Array([3.7416575, 9.486833 ], dtype=float32)
16691641 """
16701642 check_arraylike ('jnp.linalg.vector_norm' , x )
1671- if axis is None :
1672- result = norm (jnp .ravel (x ), ord = ord )
1673- if keepdims :
1674- result = lax .expand_dims (result , range (jnp .ndim (x )))
1675- return result
1676- return norm (x , axis = axis , keepdims = keepdims , ord = ord )
1677-
1643+ if ord is None or ord == 2 :
1644+ return ufuncs .sqrt (reductions .sum (ufuncs .real (x * ufuncs .conj (x )), axis = axis ,
1645+ keepdims = keepdims ))
1646+ elif ord == jnp .inf :
1647+ return reductions .amax (ufuncs .abs (x ), axis = axis , keepdims = keepdims )
1648+ elif ord == - jnp .inf :
1649+ return reductions .amin (ufuncs .abs (x ), axis = axis , keepdims = keepdims )
1650+ elif ord == 0 :
1651+ return reductions .sum (x != 0 , dtype = jnp .finfo (lax .dtype (x )).dtype ,
1652+ axis = axis , keepdims = keepdims )
1653+ elif ord == 1 :
1654+ # Numpy has a special case for ord == 1 as an optimization. We don't
1655+ # really need the optimization (XLA could do it for us), but the Numpy
1656+ # code has slightly different type promotion semantics, so we need a
1657+ # special case too.
1658+ return reductions .sum (ufuncs .abs (x ), axis = axis , keepdims = keepdims )
1659+ elif isinstance (ord , str ):
1660+ msg = f"Invalid order '{ ord } ' for vector norm."
1661+ if ord == "inf" :
1662+ msg += "Use 'jax.numpy.inf' instead."
1663+ if ord == "-inf" :
1664+ msg += "Use '-jax.numpy.inf' instead."
1665+ raise ValueError (msg )
1666+ else :
1667+ abs_x = ufuncs .abs (x )
1668+ ord_arr = lax_internal ._const (abs_x , ord )
1669+ ord_inv = lax_internal ._const (abs_x , 1. / ord_arr )
1670+ out = reductions .sum (abs_x ** ord_arr , axis = axis , keepdims = keepdims )
1671+ return ufuncs .power (out , ord_inv )
16781672
16791673@export
16801674def vecdot (x1 : ArrayLike , x2 : ArrayLike , / , * , axis : int = - 1 ,
0 commit comments