@@ -253,9 +253,51 @@ def bitwise_count(x: ArrayLike, /) -> Array:
253
253
# Following numpy we take the absolute value and return uint8.
254
254
return lax .population_count (abs (x )).astype ('uint8' )
255
255
256
- @implements (np .right_shift , module = 'numpy' )
257
256
@partial (jit , inline = True )
258
257
def right_shift (x1 : ArrayLike , x2 : ArrayLike , / ) -> Array :
258
+ r"""Right shift the bits of ``x1`` to the amount specified in ``x2``.
259
+
260
+ LAX-backend implementation of :func:`numpy.right_shift`.
261
+
262
+ Args:
263
+ x1: Input array, only accepts unsigned integer subtypes
264
+ x2: The amount of bits to shift each element in ``x1`` to the right, only accepts
265
+ integer subtypes
266
+
267
+ Returns:
268
+ An array-like object containing the right shifted elements of ``x1`` by the
269
+ amount specified in ``x2``, with the same shape as the broadcasted shape of
270
+ ``x1`` and ``x2``.
271
+
272
+ Note:
273
+ If ``x1.shape != x2.shape``, they must be compatible for broadcasting to a
274
+ shared shape, this shared shape will also be the shape of the output. Right shifting
275
+ a scalar x1 by scalar x2 is equivalent to ``x1 // 2**x2``.
276
+
277
+ Example:
278
+ >>> def print_binary(x):
279
+ ... return [bin(int(val)) for val in x]
280
+
281
+ >>> x1 = jnp.array([1, 2, 4, 8])
282
+ >>> print_binary(x1)
283
+ ['0b1', '0b10', '0b100', '0b1000']
284
+ >>> x2 = 1
285
+ >>> result = jnp.right_shift(x1, x2)
286
+ >>> result
287
+ Array([0, 1, 2, 4], dtype=int32)
288
+ >>> print_binary(result)
289
+ ['0b0', '0b1', '0b10', '0b100']
290
+
291
+ >>> x1 = 16
292
+ >>> print_binary([x1])
293
+ ['0b10000']
294
+ >>> x2 = jnp.array([1, 2, 3, 4])
295
+ >>> result = jnp.right_shift(x1, x2)
296
+ >>> result
297
+ Array([8, 4, 2, 1], dtype=int32)
298
+ >>> print_binary(result)
299
+ ['0b1000', '0b100', '0b10', '0b1']
300
+ """
259
301
x1 , x2 = promote_args_numeric (np .right_shift .__name__ , x1 , x2 )
260
302
lax_fn = lax .shift_right_logical if \
261
303
np .issubdtype (x1 .dtype , np .unsignedinteger ) else lax .shift_right_arithmetic
0 commit comments