@@ -9511,10 +9511,82 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
95119511 a , b = util .promote_dtypes (a , b )
95129512 return ravel (a )[:, None ] * ravel (b )[None , :]
95139513
9514- @ util . implements ( np . cross )
9514+
95159515@partial (jit , static_argnames = ('axisa' , 'axisb' , 'axisc' , 'axis' ))
95169516def cross (a , b , axisa : int = - 1 , axisb : int = - 1 , axisc : int = - 1 ,
95179517 axis : int | None = None ):
9518+ r"""Compute the (batched) cross product of two arrays.
9519+
9520+ JAX implementation of :func:`numpy.cross`.
9521+
9522+ This computes the 2-dimensional or 3-dimensional cross product,
9523+
9524+ .. math::
9525+
9526+ c = a \times b
9527+
9528+ In 3 dimensions, ``c`` is a length-3 array. In 2 dimensions, ``c`` is
9529+ a scalar.
9530+
9531+ Args:
9532+ a: N-dimensional array. ``a.shape[axisa]`` indicates the dimension of
9533+ the cross product, and must be 2 or 3.
9534+ b: N-dimensional array. Must have ``b.shape[axisb] == a.shape[axisb]``,
9535+ and other dimensions of ``a`` and ``b`` must be broadcast compatible.
9536+ axisa: specicy the axis of ``a`` along which to compute the cross product.
9537+ axisb: specicy the axis of ``b`` along which to compute the cross product.
9538+ axisc: specicy the axis of ``c`` along which the cross product result
9539+ will be stored.
9540+ axis: if specified, this overrides ``axisa``, ``axisb``, and ``axisc``
9541+ with a single value.
9542+
9543+ Returns:
9544+ The array ``c`` containing the (batched) cross product of ``a`` and ``b``
9545+ along the specified axes.
9546+
9547+ See also:
9548+ - :func:`jax.numpy.linalg.cross`: an array API compatible function for
9549+ computing cross products over 3-vectors.
9550+
9551+ Examples:
9552+ A 2-dimensional cross product returns a scalar:
9553+
9554+ >>> a = jnp.array([1, 2])
9555+ >>> b = jnp.array([3, 4])
9556+ >>> jnp.cross(a, b)
9557+ Array(-2, dtype=int32)
9558+
9559+ A 3-dimensional cross product returns a length-3 vector:
9560+
9561+ >>> a = jnp.array([1, 2, 3])
9562+ >>> b = jnp.array([4, 5, 6])
9563+ >>> jnp.cross(a, b)
9564+ Array([-3, 6, -3], dtype=int32)
9565+
9566+ With multi-dimensional inputs, the cross-product is computed along
9567+ the last axis by default. Here's a batched 3-dimensional cross
9568+ product, operating on the rows of the inputs:
9569+
9570+ >>> a = jnp.array([[1, 2, 3],
9571+ ... [3, 4, 3]])
9572+ >>> b = jnp.array([[2, 3, 2],
9573+ ... [4, 5, 6]])
9574+ >>> jnp.cross(a, b)
9575+ Array([[-5, 4, -1],
9576+ [ 9, -6, -1]], dtype=int32)
9577+
9578+ Specifying axis=0 makes this a batched 2-dimensional cross product,
9579+ operating on the columns of the inputs:
9580+
9581+ >>> jnp.cross(a, b, axis=0)
9582+ Array([-2, -2, 12], dtype=int32)
9583+
9584+ Equivalently, we can independently specify the axis of the inputs ``a``
9585+ and ``b`` and the output ``c``:
9586+
9587+ >>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
9588+ Array([-2, -2, 12], dtype=int32)
9589+ """
95189590 # TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here.
95199591 util .check_arraylike ("cross" , a , b )
95209592 if axis is not None :
0 commit comments