Skip to content

Commit 9b56215

Browse files
committed
Internal: create decorators for defining ufuncs
1 parent f18f62a commit 9b56215

File tree

2 files changed

+81
-76
lines changed

2 files changed

+81
-76
lines changed

jax/_src/numpy/ufuncs.py

Lines changed: 79 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ def _to_bool(x: Array) -> Array:
5757
return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0))
5858

5959

60+
def unary_ufunc(func: Callable[[ArrayLike], Array]) -> ufunc:
61+
"""An internal helper function for defining unary ufuncs."""
62+
func_jit = jit(func, inline=True)
63+
return ufunc(func_jit, name=func.__name__, nin=1, nout=1, call=func_jit)
64+
65+
66+
def binary_ufunc(identity: Any, reduce: Callable[..., Any] | None = None,
67+
accumulate: Callable[..., Any] | None = None,
68+
at: Callable[..., Any] | None = None,
69+
reduceat: Callable[..., Any] | None = None) -> Callable[[Callable[[ArrayLike, ArrayLike], Array]], ufunc]:
70+
"""An internal helper function for defining binary ufuncs."""
71+
def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc:
72+
func_jit = jit(func, inline=True)
73+
return ufunc(func_jit, name=func.__name__, nin=2, nout=1, call=func_jit,
74+
identity=identity, reduce=reduce, accumulate=accumulate, at=at, reduceat=reduceat)
75+
return decorator
76+
77+
6078
@partial(jit, inline=True)
6179
def fabs(x: ArrayLike, /) -> Array:
6280
"""Compute the element-wise absolute values of the real-valued input.
@@ -160,8 +178,8 @@ def invert(x: ArrayLike, /) -> Array:
160178
return lax.bitwise_not(*promote_args('invert', x))
161179

162180

163-
@partial(jit, inline=True)
164-
def _negative(x: ArrayLike, /) -> Array:
181+
@unary_ufunc
182+
def negative(x: ArrayLike, /) -> Array:
165183
"""Return element-wise negative values of the input.
166184
167185
JAX implementation of :obj:`numpy.negative`.
@@ -1126,8 +1144,16 @@ def cbrt(x: ArrayLike, /) -> Array:
11261144
"""
11271145
return lax.cbrt(*promote_args_inexact('cbrt', x))
11281146

1129-
@partial(jit, inline=True)
1130-
def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
1147+
def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array:
1148+
"""Implementation of jnp.add.at."""
1149+
if a.dtype == bool:
1150+
a = a.astype('int32')
1151+
b = lax.convert_element_type(b, bool).astype('int32')
1152+
return a.at[indices].add(b).astype(bool)
1153+
return a.at[indices].add(b)
1154+
1155+
@binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at)
1156+
def add(x: ArrayLike, y: ArrayLike, /) -> Array:
11311157
"""Add two arrays element-wise.
11321158
11331159
JAX implementation of :obj:`numpy.add`. This is a universal function,
@@ -1156,8 +1182,17 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
11561182
x, y = promote_args("add", x, y)
11571183
return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y)
11581184

1159-
@partial(jit, inline=True)
1160-
def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
1185+
def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array:
1186+
"""Implementation of jnp.multiply.at."""
1187+
if a.dtype == bool:
1188+
a = a.astype('int32')
1189+
b = lax.convert_element_type(b, bool).astype('int32')
1190+
return a.at[indices].mul(b).astype(bool)
1191+
else:
1192+
return a.at[indices].mul(b)
1193+
1194+
@binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at)
1195+
def multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
11611196
"""Multiply two arrays element-wise.
11621197
11631198
JAX implementation of :obj:`numpy.multiply`. This is a universal function,
@@ -1186,8 +1221,8 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
11861221
x, y = promote_args("multiply", x, y)
11871222
return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y)
11881223

1189-
@partial(jit, inline=True)
1190-
def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
1224+
@binary_ufunc(identity=-1)
1225+
def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
11911226
"""Compute the bitwise AND operation elementwise.
11921227
11931228
JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
@@ -1215,8 +1250,8 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
12151250
"""
12161251
return lax.bitwise_and(*promote_args("bitwise_and", x, y))
12171252

1218-
@partial(jit, inline=True)
1219-
def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
1253+
@binary_ufunc(identity=0)
1254+
def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
12201255
"""Compute the bitwise OR operation elementwise.
12211256
12221257
JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
@@ -1244,8 +1279,8 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
12441279
"""
12451280
return lax.bitwise_or(*promote_args("bitwise_or", x, y))
12461281

1247-
@partial(jit, inline=True)
1248-
def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
1282+
@binary_ufunc(identity=0)
1283+
def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
12491284
"""Compute the bitwise XOR operation elementwise.
12501285
12511286
JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
@@ -1433,8 +1468,12 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
14331468
return lax.ne(*promote_args("not_equal", x, y))
14341469

14351470

1436-
@partial(jit, inline=True)
1437-
def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
1471+
def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array:
1472+
"""Implementation of jnp.subtract.at."""
1473+
return a.at[indices].subtract(b)
1474+
1475+
@binary_ufunc(identity=None, at=_subtract_at)
1476+
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
14381477
"""Subtract two arrays element-wise.
14391478
14401479
JAX implementation of :obj:`numpy.subtract`. This is a universal function,
@@ -1754,8 +1793,17 @@ def spacing(x: ArrayLike, /) -> Array:
17541793

17551794

17561795
# Logical ops
1757-
@partial(jit, inline=True)
1758-
def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
1796+
def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
1797+
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
1798+
where: ArrayLike | None = None):
1799+
"""Implementation of jnp.logical_and.reduce."""
1800+
if initial is not None:
1801+
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
1802+
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
1803+
return result if dtype is None else result.astype(dtype)
1804+
1805+
@binary_ufunc(identity=True, reduce=_logical_and_reduce)
1806+
def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
17591807
"""Compute the logical AND operation elementwise.
17601808
17611809
JAX implementation of :obj:`numpy.logical_and`. This is a universal function,
@@ -1774,8 +1822,18 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
17741822
"""
17751823
return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y)))
17761824

1777-
@partial(jit, inline=True)
1778-
def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
1825+
1826+
def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
1827+
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
1828+
where: ArrayLike | None = None):
1829+
"""Implementation of jnp.logical_or.reduce."""
1830+
if initial is not None:
1831+
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
1832+
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
1833+
return result if dtype is None else result.astype(dtype)
1834+
1835+
@binary_ufunc(identity=False, reduce=_logical_or_reduce)
1836+
def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
17791837
"""Compute the logical OR operation elementwise.
17801838
17811839
JAX implementation of :obj:`numpy.logical_or`. This is a universal function,
@@ -1794,8 +1852,9 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
17941852
"""
17951853
return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y)))
17961854

1797-
@partial(jit, inline=True)
1798-
def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
1855+
1856+
@binary_ufunc(identity=False)
1857+
def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array:
17991858
"""Compute the logical XOR operation elementwise.
18001859
18011860
JAX implementation of :obj:`numpy.logical_xor`. This is a universal function,
@@ -3653,57 +3712,3 @@ def _sinc_maclaurin(k, x):
36533712
def _sinc_maclaurin_jvp(k, primals, tangents):
36543713
(x,), (t,) = primals, tangents
36553714
return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
3656-
3657-
3658-
def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
3659-
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
3660-
where: ArrayLike | None = None):
3661-
if initial is not None:
3662-
raise ValueError("initial argument not supported in jnp.logical_and.reduce()")
3663-
result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where)
3664-
return result if dtype is None else result.astype(dtype)
3665-
3666-
3667-
def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
3668-
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
3669-
where: ArrayLike | None = None):
3670-
if initial is not None:
3671-
raise ValueError("initial argument not supported in jnp.logical_or.reduce()")
3672-
result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where)
3673-
return result if dtype is None else result.astype(dtype)
3674-
3675-
def _add_at(a: Array, indices: Any, b: ArrayLike):
3676-
if a.dtype == bool:
3677-
a = a.astype('int32')
3678-
b = lax.convert_element_type(b, bool).astype('int32')
3679-
return a.at[indices].add(b).astype(bool)
3680-
return a.at[indices].add(b)
3681-
3682-
def _subtract_at(a: Array, indices: Any, b: ArrayLike):
3683-
return a.at[indices].subtract(b)
3684-
3685-
def _multiply_at(a: Array, indices: Any, b: ArrayLike):
3686-
if a.dtype == bool:
3687-
a = a.astype('int32')
3688-
b = lax.convert_element_type(b, bool).astype('int32')
3689-
return a.at[indices].mul(b).astype(bool)
3690-
else:
3691-
return a.at[indices].mul(b)
3692-
3693-
# Generate ufunc interfaces for several common binary functions.
3694-
# We start with binary ufuncs that have well-defined identities.'
3695-
# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience?
3696-
# TODO(jakevdp): optimize some implementations.
3697-
# - define add.at/multiply.at in terms of scatter_add/scatter_mul
3698-
# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
3699-
# - define all monoidal reductions in terms of lax.reduce
3700-
add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at)
3701-
multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at)
3702-
bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and)
3703-
bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or)
3704-
bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor)
3705-
logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce)
3706-
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
3707-
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
3708-
negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative)
3709-
subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at)

jax/experimental/jax2tf/tests/jax2tf_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,8 @@ def caller_jax(x):
979979
self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf))
980980
else:
981981
graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def())
982-
if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def:
983-
self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def)
982+
if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def:
983+
self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def)
984984

985985
def test_bfloat16_constant(self):
986986
# Re: https://github.com/jax-ml/jax/issues/3942

0 commit comments

Comments
 (0)