@@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
615615 """
616616 return tanh_p .bind (x )
617617
618+ @export
618619def logistic (x : ArrayLike ) -> Array :
619- r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
620+ r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
621+
622+ There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
623+ of HLO arithmetic operations.
624+
625+ Args:
626+ x: input array. Must have floating point or complex dtype.
627+
628+ Returns:
629+ Array of the same shape and dtype as ``x`` containing the element-wise
630+ logistic/sigmoid function.
631+
632+ See also:
633+ - :func:`jax.nn.sigmoid`: an alternative API for this functionality.
634+ """
620635 return logistic_p .bind (x )
621636
622637@export
@@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
10181033 """
10191034 return xor_p .bind (x , y )
10201035
1036+ @export
10211037def population_count (x : ArrayLike ) -> Array :
1022- r"""Elementwise popcount, count the number of set bits in each element."""
1038+ r"""Elementwise popcount, count the number of set bits in each element.
1039+
1040+ This function lowers directly to the `stablehlo.popcnt`_ operation.
1041+
1042+ Args:
1043+ x: Input array. Must have integer dtype.
1044+
1045+ Returns:
1046+ An array of the same shape and dtype as ``x``, containing the number of
1047+ set bits in the input.
1048+
1049+ See also:
1050+ - :func:`jax.lax.clz`: Elementwise count leading zeros.
1051+ - :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
1052+
1053+ .. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
1054+ """
10231055 return population_count_p .bind (x )
10241056
1057+ @export
10251058def clz (x : ArrayLike ) -> Array :
1026- r"""Elementwise count-leading-zeros."""
1059+ r"""Elementwise count-leading-zeros.
1060+
1061+ This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
1062+
1063+ Args:
1064+ x: Input array. Must have integer dtype.
1065+
1066+ Returns:
1067+ An array of the same shape and dtype as ``x``, containing the number of
1068+ set bits in the input.
1069+
1070+ See also:
1071+ - :func:`jax.lax.population_count`: Count the number of set bits in each element.
1072+
1073+ .. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
1074+ """
10271075 return clz_p .bind (x )
10281076
10291077@export
@@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
11241172 """
11251173 return div_p .bind (x , y )
11261174
1175+ @export
11271176def rem (x : ArrayLike , y : ArrayLike ) -> Array :
11281177 r"""Elementwise remainder: :math:`x \bmod y`.
11291178
1130- The sign of the result is taken from the dividend,
1131- and the absolute value of the result is always
1132- less than the divisor's absolute value.
1179+ This function lowers directly to the `stablehlo.remainder`_ operation.
1180+ The sign of the result is taken from the dividend, and the absolute value
1181+ of the result is always less than the divisor's absolute value.
11331182
1134- Integer division overflow
1135- (remainder by zero or remainder of INT_SMIN with -1)
1183+ Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
11361184 produces an implementation defined value.
1185+
1186+ Args:
1187+ x, y: Input arrays. Must have matching int or float dtypes. If neither
1188+ is a scalar, ``x`` and ``y`` must have the same number of dimensions
1189+ and be broadcast compatible.
1190+
1191+ Returns:
1192+ An array of the same dtype as ``x`` and ``y`` containing the remainder.
1193+
1194+ See also:
1195+ - :func:`jax.numpy.remainder`: NumPy-style remainder with different
1196+ sign semantics.
1197+
1198+ .. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
11371199 """
11381200 return rem_p .bind (x , y )
11391201
1202+ @export
11401203def max (x : ArrayLike , y : ArrayLike ) -> Array :
1141- r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
1204+ r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
1205+
1206+ This function lowers directly to the `stablehlo.maximum`_ operation for
1207+ non-complex inputs. For complex numbers, this uses a lexicographic
1208+ comparison on the `(real, imaginary)` pairs.
1209+
1210+ Args:
1211+ x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
1212+ ``x`` and ``y`` must have the same rank and be broadcast compatible.
11421213
1143- For complex numbers, uses a lexicographic comparison on the
1144- `(real, imaginary)` pairs."""
1214+ Returns:
1215+ An array of the same dtype as ``x`` and ``y`` containing the elementwise
1216+ maximum.
1217+
1218+ See also:
1219+ - :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
1220+ - :func:`jax.lax.reduce_max`: maximum along an axis of an array.
1221+ - :func:`jax.lax.min`: elementwise minimum.
1222+
1223+ .. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
1224+ """
11451225 return max_p .bind (x , y )
11461226
1227+ @export
11471228def min (x : ArrayLike , y : ArrayLike ) -> Array :
1148- r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
1229+ r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
1230+
1231+ This function lowers directly to the `stablehlo.minimum`_ operation for
1232+ non-complex inputs. For complex numbers, this uses a lexicographic
1233+ comparison on the `(real, imaginary)` pairs.
1234+
1235+ Args:
1236+ x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
1237+ ``x`` and ``y`` must have the same rank and be broadcast compatible.
11491238
1150- For complex numbers, uses a lexicographic comparison on the
1151- `(real, imaginary)` pairs."""
1239+ Returns:
1240+ An array of the same dtype as ``x`` and ``y`` containing the elementwise
1241+ minimum.
1242+
1243+ See also:
1244+ - :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
1245+ - :func:`jax.lax.reduce_min`: minimum along an axis of an array.
1246+ - :func:`jax.lax.max`: elementwise maximum.
1247+
1248+ .. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
1249+ """
11521250 return min_p .bind (x , y )
11531251
11541252@export
@@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
14081506 """
14091507 return lt_p .bind (x , y )
14101508
1509+ @export
14111510def convert_element_type (operand : ArrayLike ,
14121511 new_dtype : DTypeLike | dtypes .ExtendedDType ) -> Array :
14131512 """Elementwise cast.
14141513
1415- Wraps XLA's `ConvertElementType
1416- <https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
1417- operator, which performs an elementwise conversion from one type to another.
1418- Similar to a C++ `static_cast`.
1514+ This function lowers directly to the `stablehlo.convert`_ operation, which
1515+ performs an elementwise conversion from one type to another, similar to a
1516+ C++ ``static_cast``.
14191517
14201518 Args:
14211519 operand: an array or scalar value to be cast.
1422- new_dtype: a NumPy dtype representing the target type.
1520+ new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
1521+ or a valid dtype name) representing the target dtype.
14231522
14241523 Returns:
1425- An array with the same shape as `operand`, cast elementwise to `new_dtype`.
1524+ An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
1525+
1526+ .. note::
1527+
1528+ If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
1529+ the appropriate 32-bit type will be used in its place.
1530+
1531+ If the input is a JAX array and the input dtype and output dtype match, then
1532+ the input array will be returned unmodified.
1533+
1534+ See also:
1535+ - :func:`jax.numpy.astype`: NumPy-style dtype casting API.
1536+ - :meth:`jax.Array.astype`: dtype casting as an array method.
1537+ - :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
1538+
1539+ .. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
1540+ .. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
14261541 """
14271542 return _convert_element_type (operand , new_dtype , weak_type = False ) # type: ignore[unused-ignore,bad-return-type]
14281543
@@ -1500,12 +1615,11 @@ def _convert_element_type(
15001615 operand , new_dtype = new_dtype , weak_type = bool (weak_type ),
15011616 sharding = sharding )
15021617
1618+ @export
15031619def bitcast_convert_type (operand : ArrayLike , new_dtype : DTypeLike ) -> Array :
15041620 """Elementwise bitcast.
15051621
1506- Wraps XLA's `BitcastConvertType
1507- <https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
1508- operator, which performs a bit cast from one type to another.
1622+ This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
15091623
15101624 The output shape depends on the size of the input and output dtypes with
15111625 the following logic::
@@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
15251639 Returns:
15261640 An array of shape `output_shape` (see above) and type `new_dtype`,
15271641 constructed from the same bits as operand.
1642+
1643+ See also:
1644+ - :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
1645+ - :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
1646+
1647+ .. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
15281648 """
15291649 new_dtype = dtypes .canonicalize_dtype (new_dtype )
15301650 return bitcast_convert_type_p .bind (operand , new_dtype = new_dtype )
0 commit comments