Skip to content

Commit f6ce973

Browse files
author
jax authors
committed
Merge pull request #21745 from pkgoogle:better_right_shift_doc
PiperOrigin-RevId: 641972495
2 parents a073476 + 07d90e5 commit f6ce973

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

jax/_src/numpy/ufuncs.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,51 @@ def bitwise_count(x: ArrayLike, /) -> Array:
253253
# Following numpy we take the absolute value and return uint8.
254254
return lax.population_count(abs(x)).astype('uint8')
255255

256-
@implements(np.right_shift, module='numpy')
257256
@partial(jit, inline=True)
258257
def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array:
258+
r"""Right shift the bits of ``x1`` to the amount specified in ``x2``.
259+
260+
LAX-backend implementation of :func:`numpy.right_shift`.
261+
262+
Args:
263+
x1: Input array, only accepts unsigned integer subtypes
264+
x2: The amount of bits to shift each element in ``x1`` to the right, only accepts
265+
integer subtypes
266+
267+
Returns:
268+
An array-like object containing the right shifted elements of ``x1`` by the
269+
amount specified in ``x2``, with the same shape as the broadcasted shape of
270+
``x1`` and ``x2``.
271+
272+
Note:
273+
If ``x1.shape != x2.shape``, they must be compatible for broadcasting to a
274+
shared shape, this shared shape will also be the shape of the output. Right shifting
275+
a scalar x1 by scalar x2 is equivalent to ``x1 // 2**x2``.
276+
277+
Example:
278+
>>> def print_binary(x):
279+
... return [bin(int(val)) for val in x]
280+
281+
>>> x1 = jnp.array([1, 2, 4, 8])
282+
>>> print_binary(x1)
283+
['0b1', '0b10', '0b100', '0b1000']
284+
>>> x2 = 1
285+
>>> result = jnp.right_shift(x1, x2)
286+
>>> result
287+
Array([0, 1, 2, 4], dtype=int32)
288+
>>> print_binary(result)
289+
['0b0', '0b1', '0b10', '0b100']
290+
291+
>>> x1 = 16
292+
>>> print_binary([x1])
293+
['0b10000']
294+
>>> x2 = jnp.array([1, 2, 3, 4])
295+
>>> result = jnp.right_shift(x1, x2)
296+
>>> result
297+
Array([8, 4, 2, 1], dtype=int32)
298+
>>> print_binary(result)
299+
['0b1000', '0b100', '0b10', '0b1']
300+
"""
259301
x1, x2 = promote_args_numeric(np.right_shift.__name__, x1, x2)
260302
lax_fn = lax.shift_right_logical if \
261303
np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic

0 commit comments

Comments
 (0)