Skip to content

Commit a1140e9

Browse files
committed
Better docs for jnp.cross
1 parent ca2d158 commit a1140e9

File tree

1 file changed

+73
-1
lines changed

1 file changed

+73
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9511,10 +9511,82 @@ def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array:
95119511
a, b = util.promote_dtypes(a, b)
95129512
return ravel(a)[:, None] * ravel(b)[None, :]
95139513

9514-
@util.implements(np.cross)
9514+
95159515
@partial(jit, static_argnames=('axisa', 'axisb', 'axisc', 'axis'))
95169516
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
95179517
axis: int | None = None):
9518+
r"""Compute the (batched) cross product of two arrays.
9519+
9520+
JAX implementation of :func:`numpy.cross`.
9521+
9522+
This computes the 2-dimensional or 3-dimensional cross product,
9523+
9524+
.. math::
9525+
9526+
c = a \times b
9527+
9528+
In 3 dimensions, ``c`` is a length-3 array. In 2 dimensions, ``c`` is
9529+
a scalar.
9530+
9531+
Args:
9532+
a: N-dimensional array. ``a.shape[axisa]`` indicates the dimension of
9533+
the cross product, and must be 2 or 3.
9534+
b: N-dimensional array. Must have ``b.shape[axisb] == a.shape[axisb]``,
9535+
and other dimensions of ``a`` and ``b`` must be broadcast compatible.
9536+
axisa: specicy the axis of ``a`` along which to compute the cross product.
9537+
axisb: specicy the axis of ``b`` along which to compute the cross product.
9538+
axisc: specicy the axis of ``c`` along which the cross product result
9539+
will be stored.
9540+
axis: if specified, this overrides ``axisa``, ``axisb``, and ``axisc``
9541+
with a single value.
9542+
9543+
Returns:
9544+
The array ``c`` containing the (batched) cross product of ``a`` and ``b``
9545+
along the specified axes.
9546+
9547+
See also:
9548+
- :func:`jax.numpy.linalg.cross`: an array API compatible function for
9549+
computing cross products over 3-vectors.
9550+
9551+
Examples:
9552+
A 2-dimensional cross product returns a scalar:
9553+
9554+
>>> a = jnp.array([1, 2])
9555+
>>> b = jnp.array([3, 4])
9556+
>>> jnp.cross(a, b)
9557+
Array(-2, dtype=int32)
9558+
9559+
A 3-dimensional cross product returns a length-3 vector:
9560+
9561+
>>> a = jnp.array([1, 2, 3])
9562+
>>> b = jnp.array([4, 5, 6])
9563+
>>> jnp.cross(a, b)
9564+
Array([-3, 6, -3], dtype=int32)
9565+
9566+
With multi-dimensional inputs, the cross-product is computed along
9567+
the last axis by default. Here's a batched 3-dimensional cross
9568+
product, operating on the rows of the inputs:
9569+
9570+
>>> a = jnp.array([[1, 2, 3],
9571+
... [3, 4, 3]])
9572+
>>> b = jnp.array([[2, 3, 2],
9573+
... [4, 5, 6]])
9574+
>>> jnp.cross(a, b)
9575+
Array([[-5, 4, -1],
9576+
[ 9, -6, -1]], dtype=int32)
9577+
9578+
Specifying axis=0 makes this a batched 2-dimensional cross product,
9579+
operating on the columns of the inputs:
9580+
9581+
>>> jnp.cross(a, b, axis=0)
9582+
Array([-2, -2, 12], dtype=int32)
9583+
9584+
Equivalently, we can independently specify the axis of the inputs ``a``
9585+
and ``b`` and the output ``c``:
9586+
9587+
>>> jnp.cross(a, b, axisa=0, axisb=0, axisc=0)
9588+
Array([-2, -2, 12], dtype=int32)
9589+
"""
95189590
# TODO(jakevdp): NumPy 2.0 deprecates 2D inputs. Follow suit here.
95199591
util.check_arraylike("cross", a, b)
95209592
if axis is not None:

0 commit comments

Comments
 (0)