Skip to content

Commit 6d94ae3

Browse files
Improve docs for jnp.angle and jnp.flip
1 parent 55d0f5e commit 6d94ae3

File tree

1 file changed

+93
-2
lines changed

1 file changed

+93
-2
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -716,8 +716,62 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
716716
return flip(transpose(m, perm), ax2)
717717

718718

719-
@util.implements(np.flip, lax_description=_ARRAY_VIEW_DOC)
720719
def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
720+
"""Reverse the order of elements of an array along the given axis.
721+
722+
JAX implementation of :func:`numpy.flip`.
723+
724+
Args:
725+
m: Array.
726+
axis: integer or sequence of integers. Specifies along which axis or axes
727+
should the array elements be reversed. Default is ``None``, which flips
728+
along all axes.
729+
730+
Returns:
731+
An array with the elements in reverse order along ``axis``.
732+
733+
See Also:
734+
- :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right)
735+
- :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down)
736+
737+
Example:
738+
>>> x1 = jnp.array([[1, 2],
739+
... [3, 4]])
740+
>>> jnp.flip(x1)
741+
Array([[4, 3],
742+
[2, 1]], dtype=int32)
743+
744+
If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses
745+
the array along that particular axis only.
746+
747+
>>> jnp.flip(x1, axis=1)
748+
Array([[2, 1],
749+
[4, 3]], dtype=int32)
750+
751+
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
752+
>>> x2
753+
Array([[[1, 2],
754+
[3, 4]],
755+
<BLANKLINE>
756+
[[5, 6],
757+
[7, 8]]], dtype=int32)
758+
>>> jnp.flip(x2)
759+
Array([[[8, 7],
760+
[6, 5]],
761+
<BLANKLINE>
762+
[[4, 3],
763+
[2, 1]]], dtype=int32)
764+
765+
When ``axis`` is specified with a sequence of integers, then
766+
``jax.numpy.flip`` reverses the array along the specified axes.
767+
768+
>>> jnp.flip(x2, axis=[1, 2])
769+
Array([[[4, 3],
770+
[2, 1]],
771+
<BLANKLINE>
772+
[[8, 7],
773+
[6, 5]]], dtype=int32)
774+
"""
721775
util.check_arraylike("flip", m)
722776
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
723777

@@ -752,9 +806,46 @@ def isreal(x: ArrayLike) -> Array:
752806
i = ufuncs.imag(x)
753807
return lax.eq(i, _lax_const(i, 0))
754808

755-
@util.implements(np.angle)
809+
756810
@partial(jit, static_argnames=['deg'])
757811
def angle(z: ArrayLike, deg: bool = False) -> Array:
812+
"""Return the angle of a complex valued number or array.
813+
814+
JAX implementation of :func:`numpy.angle`.
815+
816+
Args:
817+
z: A complex number or an array of complex numbers.
818+
deg: Boolean. If ``True``, returns the result in degrees else returns
819+
in radians. Default is ``False``.
820+
821+
Returns:
822+
An array of counterclockwise angle of each element of ``z``, with the same
823+
shape as ``z`` of dtype float.
824+
825+
Example:
826+
827+
If ``z`` is a number
828+
829+
>>> z1 = 2+3j
830+
>>> jnp.angle(z1)
831+
Array(0.98279375, dtype=float32, weak_type=True)
832+
833+
If ``z`` is an array
834+
835+
>>> z2 = jnp.array([[1+3j, 2-5j],
836+
... [4-3j, 3+2j]])
837+
>>> with jnp.printoptions(precision=2, suppress=True):
838+
... print(jnp.angle(z2))
839+
[[ 1.25 -1.19]
840+
[-0.64 0.59]]
841+
842+
If ``deg=True``.
843+
844+
>>> with jnp.printoptions(precision=2, suppress=True):
845+
... print(jnp.angle(z2, deg=True))
846+
[[ 71.57 -68.2 ]
847+
[-36.87 33.69]]
848+
"""
758849
re = ufuncs.real(z)
759850
im = ufuncs.imag(z)
760851
dtype = _dtype(re)

0 commit comments

Comments
 (0)