Skip to content

Commit 8b46e53

Browse files
committed
jax.lax: improve docs for several APIs
1 parent 1e36cbe commit 8b46e53

File tree

1 file changed

+143
-23
lines changed

1 file changed

+143
-23
lines changed

jax/_src/lax/lax.py

Lines changed: 143 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
615615
"""
616616
return tanh_p.bind(x)
617617

618+
@export
618619
def logistic(x: ArrayLike) -> Array:
619-
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
620+
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
621+
622+
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
623+
of HLO arithmetic operations.
624+
625+
Args:
626+
x: input array. Must have floating point or complex dtype.
627+
628+
Returns:
629+
Array of the same shape and dtype as ``x`` containing the element-wise
630+
logistic/sigmoid function.
631+
632+
See also:
633+
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
634+
"""
620635
return logistic_p.bind(x)
621636

622637
@export
@@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
10181033
"""
10191034
return xor_p.bind(x, y)
10201035

1036+
@export
10211037
def population_count(x: ArrayLike) -> Array:
1022-
r"""Elementwise popcount, count the number of set bits in each element."""
1038+
r"""Elementwise popcount, count the number of set bits in each element.
1039+
1040+
This function lowers directly to the `stablehlo.popcnt`_ operation.
1041+
1042+
Args:
1043+
x: Input array. Must have integer dtype.
1044+
1045+
Returns:
1046+
An array of the same shape and dtype as ``x``, containing the number of
1047+
set bits in the input.
1048+
1049+
See also:
1050+
- :func:`jax.lax.clz`: Elementwise count leading zeros.
1051+
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
1052+
1053+
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
1054+
"""
10231055
return population_count_p.bind(x)
10241056

1057+
@export
10251058
def clz(x: ArrayLike) -> Array:
1026-
r"""Elementwise count-leading-zeros."""
1059+
r"""Elementwise count-leading-zeros.
1060+
1061+
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
1062+
1063+
Args:
1064+
x: Input array. Must have integer dtype.
1065+
1066+
Returns:
1067+
An array of the same shape and dtype as ``x``, containing the number of
1068+
set bits in the input.
1069+
1070+
See also:
1071+
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
1072+
1073+
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
1074+
"""
10271075
return clz_p.bind(x)
10281076

10291077
@export
@@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
11241172
"""
11251173
return div_p.bind(x, y)
11261174

1175+
@export
11271176
def rem(x: ArrayLike, y: ArrayLike) -> Array:
11281177
r"""Elementwise remainder: :math:`x \bmod y`.
11291178
1130-
The sign of the result is taken from the dividend,
1131-
and the absolute value of the result is always
1132-
less than the divisor's absolute value.
1179+
This function lowers directly to the `stablehlo.remainder`_ operation.
1180+
The sign of the result is taken from the dividend, and the absolute value
1181+
of the result is always less than the divisor's absolute value.
11331182
1134-
Integer division overflow
1135-
(remainder by zero or remainder of INT_SMIN with -1)
1183+
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
11361184
produces an implementation defined value.
1185+
1186+
Args:
1187+
x, y: Input arrays. Must have matching int or float dtypes. If neither
1188+
is a scalar, ``x`` and ``y`` must have the same number of dimensions
1189+
and be broadcast compatible.
1190+
1191+
Returns:
1192+
An array of the same dtype as ``x`` and ``y`` containing the remainder.
1193+
1194+
See also:
1195+
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
1196+
sign semantics.
1197+
1198+
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
11371199
"""
11381200
return rem_p.bind(x, y)
11391201

1202+
@export
11401203
def max(x: ArrayLike, y: ArrayLike) -> Array:
1141-
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
1204+
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
1205+
1206+
This function lowers directly to the `stablehlo.maximum`_ operation for
1207+
non-complex inputs. For complex numbers, this uses a lexicographic
1208+
comparison on the `(real, imaginary)` pairs.
1209+
1210+
Args:
1211+
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
1212+
``x`` and ``y`` must have the same rank and be broadcast compatible.
11421213
1143-
For complex numbers, uses a lexicographic comparison on the
1144-
`(real, imaginary)` pairs."""
1214+
Returns:
1215+
An array of the same dtype as ``x`` and ``y`` containing the elementwise
1216+
maximum.
1217+
1218+
See also:
1219+
- :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
1220+
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
1221+
- :func:`jax.lax.min`: elementwise minimum.
1222+
1223+
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
1224+
"""
11451225
return max_p.bind(x, y)
11461226

1227+
@export
11471228
def min(x: ArrayLike, y: ArrayLike) -> Array:
1148-
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
1229+
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
1230+
1231+
This function lowers directly to the `stablehlo.minimum`_ operation for
1232+
non-complex inputs. For complex numbers, this uses a lexicographic
1233+
comparison on the `(real, imaginary)` pairs.
1234+
1235+
Args:
1236+
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
1237+
``x`` and ``y`` must have the same rank and be broadcast compatible.
11491238
1150-
For complex numbers, uses a lexicographic comparison on the
1151-
`(real, imaginary)` pairs."""
1239+
Returns:
1240+
An array of the same dtype as ``x`` and ``y`` containing the elementwise
1241+
minimum.
1242+
1243+
See also:
1244+
- :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
1245+
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
1246+
- :func:`jax.lax.max`: elementwise maximum.
1247+
1248+
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
1249+
"""
11521250
return min_p.bind(x, y)
11531251

11541252
@export
@@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
14081506
"""
14091507
return lt_p.bind(x, y)
14101508

1509+
@export
14111510
def convert_element_type(operand: ArrayLike,
14121511
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
14131512
"""Elementwise cast.
14141513
1415-
Wraps XLA's `ConvertElementType
1416-
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
1417-
operator, which performs an elementwise conversion from one type to another.
1418-
Similar to a C++ `static_cast`.
1514+
This function lowers directly to the `stablehlo.convert`_ operation, which
1515+
performs an elementwise conversion from one type to another, similar to a
1516+
C++ ``static_cast``.
14191517
14201518
Args:
14211519
operand: an array or scalar value to be cast.
1422-
new_dtype: a NumPy dtype representing the target type.
1520+
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
1521+
or a valid dtype name) representing the target dtype.
14231522
14241523
Returns:
1425-
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
1524+
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
1525+
1526+
.. note::
1527+
1528+
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
1529+
the appropriate 32-bit type will be used in its place.
1530+
1531+
If the input is a JAX array and the input dtype and output dtype match, then
1532+
the input array will be returned unmodified.
1533+
1534+
See also:
1535+
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
1536+
- :meth:`jax.Array.astype`: dtype casting as an array method.
1537+
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
1538+
1539+
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
1540+
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
14261541
"""
14271542
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
14281543

@@ -1500,12 +1615,11 @@ def _convert_element_type(
15001615
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
15011616
sharding=sharding)
15021617

1618+
@export
15031619
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
15041620
"""Elementwise bitcast.
15051621
1506-
Wraps XLA's `BitcastConvertType
1507-
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
1508-
operator, which performs a bit cast from one type to another.
1622+
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
15091623
15101624
The output shape depends on the size of the input and output dtypes with
15111625
the following logic::
@@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
15251639
Returns:
15261640
An array of shape `output_shape` (see above) and type `new_dtype`,
15271641
constructed from the same bits as operand.
1642+
1643+
See also:
1644+
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
1645+
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
1646+
1647+
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
15281648
"""
15291649
new_dtype = dtypes.canonicalize_dtype(new_dtype)
15301650
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)

0 commit comments

Comments
 (0)