@@ -716,8 +716,62 @@ def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array:
716
716
return flip (transpose (m , perm ), ax2 )
717
717
718
718
719
- @util .implements (np .flip , lax_description = _ARRAY_VIEW_DOC )
720
719
def flip (m : ArrayLike , axis : int | Sequence [int ] | None = None ) -> Array :
720
+ """Reverse the order of elements of an array along the given axis.
721
+
722
+ JAX implementation of :func:`numpy.flip`.
723
+
724
+ Args:
725
+ m: Array.
726
+ axis: integer or sequence of integers. Specifies along which axis or axes
727
+ should the array elements be reversed. Default is ``None``, which flips
728
+ along all axes.
729
+
730
+ Returns:
731
+ An array with the elements in reverse order along ``axis``.
732
+
733
+ See Also:
734
+ - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right)
735
+ - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down)
736
+
737
+ Example:
738
+ >>> x1 = jnp.array([[1, 2],
739
+ ... [3, 4]])
740
+ >>> jnp.flip(x1)
741
+ Array([[4, 3],
742
+ [2, 1]], dtype=int32)
743
+
744
+ If ``axis`` is specified with an integer, then ``jax.numpy.flip`` reverses
745
+ the array along that particular axis only.
746
+
747
+ >>> jnp.flip(x1, axis=1)
748
+ Array([[2, 1],
749
+ [4, 3]], dtype=int32)
750
+
751
+ >>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
752
+ >>> x2
753
+ Array([[[1, 2],
754
+ [3, 4]],
755
+ <BLANKLINE>
756
+ [[5, 6],
757
+ [7, 8]]], dtype=int32)
758
+ >>> jnp.flip(x2)
759
+ Array([[[8, 7],
760
+ [6, 5]],
761
+ <BLANKLINE>
762
+ [[4, 3],
763
+ [2, 1]]], dtype=int32)
764
+
765
+ When ``axis`` is specified with a sequence of integers, then
766
+ ``jax.numpy.flip`` reverses the array along the specified axes.
767
+
768
+ >>> jnp.flip(x2, axis=[1, 2])
769
+ Array([[[4, 3],
770
+ [2, 1]],
771
+ <BLANKLINE>
772
+ [[8, 7],
773
+ [6, 5]]], dtype=int32)
774
+ """
721
775
util .check_arraylike ("flip" , m )
722
776
return _flip (asarray (m ), reductions ._ensure_optional_axes (axis ))
723
777
@@ -752,9 +806,46 @@ def isreal(x: ArrayLike) -> Array:
752
806
i = ufuncs .imag (x )
753
807
return lax .eq (i , _lax_const (i , 0 ))
754
808
755
- @ util . implements ( np . angle )
809
+
756
810
@partial (jit , static_argnames = ['deg' ])
757
811
def angle (z : ArrayLike , deg : bool = False ) -> Array :
812
+ """Return the angle of a complex valued number or array.
813
+
814
+ JAX implementation of :func:`numpy.angle`.
815
+
816
+ Args:
817
+ z: A complex number or an array of complex numbers.
818
+ deg: Boolean. If ``True``, returns the result in degrees else returns
819
+ in radians. Default is ``False``.
820
+
821
+ Returns:
822
+ An array of counterclockwise angle of each element of ``z``, with the same
823
+ shape as ``z`` of dtype float.
824
+
825
+ Example:
826
+
827
+ If ``z`` is a number
828
+
829
+ >>> z1 = 2+3j
830
+ >>> jnp.angle(z1)
831
+ Array(0.98279375, dtype=float32, weak_type=True)
832
+
833
+ If ``z`` is an array
834
+
835
+ >>> z2 = jnp.array([[1+3j, 2-5j],
836
+ ... [4-3j, 3+2j]])
837
+ >>> with jnp.printoptions(precision=2, suppress=True):
838
+ ... print(jnp.angle(z2))
839
+ [[ 1.25 -1.19]
840
+ [-0.64 0.59]]
841
+
842
+ If ``deg=True``.
843
+
844
+ >>> with jnp.printoptions(precision=2, suppress=True):
845
+ ... print(jnp.angle(z2, deg=True))
846
+ [[ 71.57 -68.2 ]
847
+ [-36.87 33.69]]
848
+ """
758
849
re = ufuncs .real (z )
759
850
im = ufuncs .imag (z )
760
851
dtype = _dtype (re )
0 commit comments