Skip to content

Commit 6e1aa3c

Browse files
committed
Specialize ufunc.reduce for monoidal binary ufuncs.
1 parent 3a5ac48 commit 6e1aa3c

File tree

2 files changed

+69
-25
lines changed

2 files changed

+69
-25
lines changed

jax/_src/numpy/reductions.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ def _cast_to_bool(operand: ArrayLike) -> Array:
192192
def _cast_to_numeric(operand: ArrayLike) -> Array:
193193
return promote_dtypes_numeric(operand)[0]
194194

195+
def _require_integer(operand: ArrayLike) -> Array:
196+
arr = lax_internal.asarray(operand)
197+
if not dtypes.isdtype(arr, ("bool", "integral")):
198+
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
199+
return arr
195200

196201
def _ensure_optional_axes(x: Axis) -> Axis:
197202
def force(x):
@@ -652,6 +657,63 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None,
652657
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
653658
keepdims=keepdims, where=where)
654659

660+
661+
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
662+
def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
663+
out: None = None, keepdims: bool = False,
664+
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
665+
arr = lax_internal.asarray(a)
666+
init_val = np.array(-1, dtype=dtype or arr.dtype)
667+
return _reduction(arr, "reduce_bitwise_and", None, lax.bitwise_and, init_val, preproc=_require_integer,
668+
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
669+
initial=initial, where_=where)
670+
671+
672+
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
673+
def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
674+
out: None = None, keepdims: bool = False,
675+
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
676+
return _reduction(a, "reduce_bitwise_or", None, lax.bitwise_or, 0, preproc=_require_integer,
677+
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
678+
initial=initial, where_=where)
679+
680+
681+
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
682+
def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
683+
out: None = None, keepdims: bool = False,
684+
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
685+
return _reduction(a, "reduce_bitwise_xor", None, lax.bitwise_xor, 0, preproc=_require_integer,
686+
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
687+
initial=initial, where_=where)
688+
689+
690+
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
691+
def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
692+
out: None = None, keepdims: bool = False,
693+
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
694+
return _reduction(a, "reduce_logical_and", None, lax.bitwise_and, True, preproc=_cast_to_bool,
695+
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
696+
initial=initial, where_=where)
697+
698+
699+
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
700+
def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
701+
out: None = None, keepdims: bool = False,
702+
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
703+
return _reduction(a, "reduce_logical_or", None, lax.bitwise_or, False, preproc=_cast_to_bool,
704+
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
705+
initial=initial, where_=where)
706+
707+
708+
@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
709+
def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
710+
out: None = None, keepdims: bool = False,
711+
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
712+
return _reduction(a, "reduce_logical_xor", None, lax.bitwise_xor, False, preproc=_cast_to_bool,
713+
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
714+
initial=initial, where_=where)
715+
716+
655717
def amin(a: ArrayLike, axis: Axis = None, out: None = None,
656718
keepdims: bool = False, initial: ArrayLike | None = None,
657719
where: ArrayLike | None = None) -> Array:

jax/_src/numpy/ufuncs.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from jax._src.custom_derivatives import custom_jvp
3232
from jax._src.lax import lax
3333
from jax._src.lax import other as lax_other
34-
from jax._src.typing import Array, ArrayLike, DTypeLike
34+
from jax._src.typing import Array, ArrayLike
3535
from jax._src.numpy.util import (
3636
check_arraylike, promote_args, promote_args_inexact,
3737
promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric,
@@ -1221,7 +1221,7 @@ def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
12211221
x, y = promote_args("multiply", x, y)
12221222
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
12231223

1224-
@binary_ufunc(identity=-1)
1224+
@binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and)
12251225
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
12261226
"""Compute the bitwise AND operation elementwise.
12271227
@@ -1250,7 +1250,7 @@ def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
12501250
"""
12511251
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
12521252

1253-
@binary_ufunc(identity=0)
1253+
@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or)
12541254
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
12551255
"""Compute the bitwise OR operation elementwise.
12561256
@@ -1279,7 +1279,7 @@ def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
12791279
"""
12801280
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
12811281

1282-
@binary_ufunc(identity=0)
1282+
@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor)
12831283
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
12841284
"""Compute the bitwise XOR operation elementwise.
12851285
@@ -1793,16 +1793,7 @@ def spacing(x: ArrayLike, /) -> Array:
17931793

17941794

17951795
# Logical ops
1796-
def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
1797-
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
1798-
where: ArrayLike | None = None):
1799-
"""Implementation of jnp.logical_and.reduce."""
1800-
if initial is not None:
1801-
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
1802-
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
1803-
return result if dtype is None else result.astype(dtype)
1804-
1805-
@binary_ufunc(identity=True, reduce=_logical_and_reduce)
1796+
@binary_ufunc(identity=True, reduce=reductions._reduce_logical_and)
18061797
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
18071798
"""Compute the logical AND operation elementwise.
18081799
@@ -1823,16 +1814,7 @@ def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
18231814
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
18241815

18251816

1826-
def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
1827-
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
1828-
where: ArrayLike | None = None):
1829-
"""Implementation of jnp.logical_or.reduce."""
1830-
if initial is not None:
1831-
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
1832-
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
1833-
return result if dtype is None else result.astype(dtype)
1834-
1835-
@binary_ufunc(identity=False, reduce=_logical_or_reduce)
1817+
@binary_ufunc(identity=False, reduce=reductions._reduce_logical_or)
18361818
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
18371819
"""Compute the logical OR operation elementwise.
18381820
@@ -1853,7 +1835,7 @@ def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
18531835
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
18541836

18551837

1856-
@binary_ufunc(identity=False)
1838+
@binary_ufunc(identity=False, reduce=reductions._reduce_logical_xor)
18571839
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
18581840
"""Compute the logical XOR operation elementwise.
18591841

0 commit comments

Comments
 (0)