@@ -57,6 +57,24 @@ def _to_bool(x: Array) -> Array:
5757 return x if x .dtype == bool else lax .ne (x , _lax_const (x , 0 ))
5858
5959
60+ def unary_ufunc (func : Callable [[ArrayLike ], Array ]) -> ufunc :
61+ """An internal helper function for defining unary ufuncs."""
62+ func_jit = jit (func , inline = True )
63+ return ufunc (func_jit , name = func .__name__ , nin = 1 , nout = 1 , call = func_jit )
64+
65+
66+ def binary_ufunc (identity : Any , reduce : Callable [..., Any ] | None = None ,
67+ accumulate : Callable [..., Any ] | None = None ,
68+ at : Callable [..., Any ] | None = None ,
69+ reduceat : Callable [..., Any ] | None = None ) -> Callable [[Callable [[ArrayLike , ArrayLike ], Array ]], ufunc ]:
70+ """An internal helper function for defining binary ufuncs."""
71+ def decorator (func : Callable [[ArrayLike , ArrayLike ], Array ]) -> ufunc :
72+ func_jit = jit (func , inline = True )
73+ return ufunc (func_jit , name = func .__name__ , nin = 2 , nout = 1 , call = func_jit ,
74+ identity = identity , reduce = reduce , accumulate = accumulate , at = at , reduceat = reduceat )
75+ return decorator
76+
77+
6078@partial (jit , inline = True )
6179def fabs (x : ArrayLike , / ) -> Array :
6280 """Compute the element-wise absolute values of the real-valued input.
@@ -160,8 +178,8 @@ def invert(x: ArrayLike, /) -> Array:
160178 return lax .bitwise_not (* promote_args ('invert' , x ))
161179
162180
163- @partial ( jit , inline = True )
164- def _negative (x : ArrayLike , / ) -> Array :
181+ @unary_ufunc
182+ def negative (x : ArrayLike , / ) -> Array :
165183 """Return element-wise negative values of the input.
166184
167185 JAX implementation of :obj:`numpy.negative`.
@@ -1126,8 +1144,16 @@ def cbrt(x: ArrayLike, /) -> Array:
11261144 """
11271145 return lax .cbrt (* promote_args_inexact ('cbrt' , x ))
11281146
1129- @partial (jit , inline = True )
1130- def _add (x : ArrayLike , y : ArrayLike , / ) -> Array :
1147+ def _add_at (a : Array , indices : Any , b : ArrayLike ) -> Array :
1148+ """Implementation of jnp.add.at."""
1149+ if a .dtype == bool :
1150+ a = a .astype ('int32' )
1151+ b = lax .convert_element_type (b , bool ).astype ('int32' )
1152+ return a .at [indices ].add (b ).astype (bool )
1153+ return a .at [indices ].add (b )
1154+
1155+ @binary_ufunc (identity = 0 , reduce = reductions .sum , accumulate = reductions .cumsum , at = _add_at )
1156+ def add (x : ArrayLike , y : ArrayLike , / ) -> Array :
11311157 """Add two arrays element-wise.
11321158
11331159 JAX implementation of :obj:`numpy.add`. This is a universal function,
@@ -1156,8 +1182,17 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array:
11561182 x , y = promote_args ("add" , x , y )
11571183 return lax .add (x , y ) if x .dtype != bool else lax .bitwise_or (x , y )
11581184
1159- @partial (jit , inline = True )
1160- def _multiply (x : ArrayLike , y : ArrayLike , / ) -> Array :
1185+ def _multiply_at (a : Array , indices : Any , b : ArrayLike ) -> Array :
1186+ """Implementation of jnp.multiply.at."""
1187+ if a .dtype == bool :
1188+ a = a .astype ('int32' )
1189+ b = lax .convert_element_type (b , bool ).astype ('int32' )
1190+ return a .at [indices ].mul (b ).astype (bool )
1191+ else :
1192+ return a .at [indices ].mul (b )
1193+
1194+ @binary_ufunc (identity = 1 , reduce = reductions .prod , accumulate = reductions .cumprod , at = _multiply_at )
1195+ def multiply (x : ArrayLike , y : ArrayLike , / ) -> Array :
11611196 """Multiply two arrays element-wise.
11621197
11631198 JAX implementation of :obj:`numpy.multiply`. This is a universal function,
@@ -1186,8 +1221,8 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array:
11861221 x , y = promote_args ("multiply" , x , y )
11871222 return lax .mul (x , y ) if x .dtype != bool else lax .bitwise_and (x , y )
11881223
1189- @partial ( jit , inline = True )
1190- def _bitwise_and (x : ArrayLike , y : ArrayLike , / ) -> Array :
1224+ @binary_ufunc ( identity = - 1 )
1225+ def bitwise_and (x : ArrayLike , y : ArrayLike , / ) -> Array :
11911226 """Compute the bitwise AND operation elementwise.
11921227
11931228 JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function,
@@ -1215,8 +1250,8 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array:
12151250 """
12161251 return lax .bitwise_and (* promote_args ("bitwise_and" , x , y ))
12171252
1218- @partial ( jit , inline = True )
1219- def _bitwise_or (x : ArrayLike , y : ArrayLike , / ) -> Array :
1253+ @binary_ufunc ( identity = 0 )
1254+ def bitwise_or (x : ArrayLike , y : ArrayLike , / ) -> Array :
12201255 """Compute the bitwise OR operation elementwise.
12211256
12221257 JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function,
@@ -1244,8 +1279,8 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array:
12441279 """
12451280 return lax .bitwise_or (* promote_args ("bitwise_or" , x , y ))
12461281
1247- @partial ( jit , inline = True )
1248- def _bitwise_xor (x : ArrayLike , y : ArrayLike , / ) -> Array :
1282+ @binary_ufunc ( identity = 0 )
1283+ def bitwise_xor (x : ArrayLike , y : ArrayLike , / ) -> Array :
12491284 """Compute the bitwise XOR operation elementwise.
12501285
12511286 JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function,
@@ -1433,8 +1468,12 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
14331468 return lax .ne (* promote_args ("not_equal" , x , y ))
14341469
14351470
1436- @partial (jit , inline = True )
1437- def _subtract (x : ArrayLike , y : ArrayLike , / ) -> Array :
1471+ def _subtract_at (a : Array , indices : Any , b : ArrayLike ) -> Array :
1472+ """Implementation of jnp.subtract.at."""
1473+ return a .at [indices ].subtract (b )
1474+
1475+ @binary_ufunc (identity = None , at = _subtract_at )
1476+ def subtract (x : ArrayLike , y : ArrayLike , / ) -> Array :
14381477 """Subtract two arrays element-wise.
14391478
14401479 JAX implementation of :obj:`numpy.subtract`. This is a universal function,
@@ -1754,8 +1793,17 @@ def spacing(x: ArrayLike, /) -> Array:
17541793
17551794
17561795# Logical ops
1757- @partial (jit , inline = True )
1758- def _logical_and (x : ArrayLike , y : ArrayLike , / ) -> Array :
1796+ def _logical_and_reduce (a : ArrayLike , axis : int = 0 , dtype : DTypeLike | None = None ,
1797+ out : None = None , keepdims : bool = False , initial : ArrayLike | None = None ,
1798+ where : ArrayLike | None = None ):
1799+ """Implementation of jnp.logical_and.reduce."""
1800+ if initial is not None :
1801+ raise ValueError ("initial argument not supported in jnp.logical_and.reduce()" )
1802+ result = reductions .all (a , axis = axis , out = out , keepdims = keepdims , where = where )
1803+ return result if dtype is None else result .astype (dtype )
1804+
1805+ @binary_ufunc (identity = True , reduce = _logical_and_reduce )
1806+ def logical_and (x : ArrayLike , y : ArrayLike , / ) -> Array :
17591807 """Compute the logical AND operation elementwise.
17601808
17611809 JAX implementation of :obj:`numpy.logical_and`. This is a universal function,
@@ -1774,8 +1822,18 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array:
17741822 """
17751823 return lax .bitwise_and (* map (_to_bool , promote_args ("logical_and" , x , y )))
17761824
1777- @partial (jit , inline = True )
1778- def _logical_or (x : ArrayLike , y : ArrayLike , / ) -> Array :
1825+
1826+ def _logical_or_reduce (a : ArrayLike , axis : int = 0 , dtype : DTypeLike | None = None ,
1827+ out : None = None , keepdims : bool = False , initial : ArrayLike | None = None ,
1828+ where : ArrayLike | None = None ):
1829+ """Implementation of jnp.logical_or.reduce."""
1830+ if initial is not None :
1831+ raise ValueError ("initial argument not supported in jnp.logical_or.reduce()" )
1832+ result = reductions .any (a , axis = axis , out = out , keepdims = keepdims , where = where )
1833+ return result if dtype is None else result .astype (dtype )
1834+
1835+ @binary_ufunc (identity = False , reduce = _logical_or_reduce )
1836+ def logical_or (x : ArrayLike , y : ArrayLike , / ) -> Array :
17791837 """Compute the logical OR operation elementwise.
17801838
17811839 JAX implementation of :obj:`numpy.logical_or`. This is a universal function,
@@ -1794,8 +1852,9 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array:
17941852 """
17951853 return lax .bitwise_or (* map (_to_bool , promote_args ("logical_or" , x , y )))
17961854
1797- @partial (jit , inline = True )
1798- def _logical_xor (x : ArrayLike , y : ArrayLike , / ) -> Array :
1855+
1856+ @binary_ufunc (identity = False )
1857+ def logical_xor (x : ArrayLike , y : ArrayLike , / ) -> Array :
17991858 """Compute the logical XOR operation elementwise.
18001859
18011860 JAX implementation of :obj:`numpy.logical_xor`. This is a universal function,
@@ -3653,57 +3712,3 @@ def _sinc_maclaurin(k, x):
36533712def _sinc_maclaurin_jvp (k , primals , tangents ):
36543713 (x ,), (t ,) = primals , tangents
36553714 return _sinc_maclaurin (k , x ), _sinc_maclaurin (k + 1 , x ) * t
3656-
3657-
3658- def _logical_and_reduce (a : ArrayLike , axis : int = 0 , dtype : DTypeLike | None = None ,
3659- out : None = None , keepdims : bool = False , initial : ArrayLike | None = None ,
3660- where : ArrayLike | None = None ):
3661- if initial is not None :
3662- raise ValueError ("initial argument not supported in jnp.logical_and.reduce()" )
3663- result = reductions .all (a , axis = axis , out = out , keepdims = keepdims , where = where )
3664- return result if dtype is None else result .astype (dtype )
3665-
3666-
3667- def _logical_or_reduce (a : ArrayLike , axis : int = 0 , dtype : DTypeLike | None = None ,
3668- out : None = None , keepdims : bool = False , initial : ArrayLike | None = None ,
3669- where : ArrayLike | None = None ):
3670- if initial is not None :
3671- raise ValueError ("initial argument not supported in jnp.logical_or.reduce()" )
3672- result = reductions .any (a , axis = axis , out = out , keepdims = keepdims , where = where )
3673- return result if dtype is None else result .astype (dtype )
3674-
3675- def _add_at (a : Array , indices : Any , b : ArrayLike ):
3676- if a .dtype == bool :
3677- a = a .astype ('int32' )
3678- b = lax .convert_element_type (b , bool ).astype ('int32' )
3679- return a .at [indices ].add (b ).astype (bool )
3680- return a .at [indices ].add (b )
3681-
3682- def _subtract_at (a : Array , indices : Any , b : ArrayLike ):
3683- return a .at [indices ].subtract (b )
3684-
3685- def _multiply_at (a : Array , indices : Any , b : ArrayLike ):
3686- if a .dtype == bool :
3687- a = a .astype ('int32' )
3688- b = lax .convert_element_type (b , bool ).astype ('int32' )
3689- return a .at [indices ].mul (b ).astype (bool )
3690- else :
3691- return a .at [indices ].mul (b )
3692-
3693- # Generate ufunc interfaces for several common binary functions.
3694- # We start with binary ufuncs that have well-defined identities.'
3695- # TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience?
3696- # TODO(jakevdp): optimize some implementations.
3697- # - define add.at/multiply.at in terms of scatter_add/scatter_mul
3698- # - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod
3699- # - define all monoidal reductions in terms of lax.reduce
3700- add = ufunc (_add , name = "add" , nin = 2 , nout = 1 , identity = 0 , call = _add , reduce = reductions .sum , accumulate = reductions .cumsum , at = _add_at )
3701- multiply = ufunc (_multiply , name = "multiply" , nin = 2 , nout = 1 , identity = 1 , call = _multiply , reduce = reductions .prod , accumulate = reductions .cumprod , at = _multiply_at )
3702- bitwise_and = ufunc (_bitwise_and , name = "bitwise_and" , nin = 2 , nout = 1 , identity = - 1 , call = _bitwise_and )
3703- bitwise_or = ufunc (_bitwise_or , name = "bitwise_or" , nin = 2 , nout = 1 , identity = 0 , call = _bitwise_or )
3704- bitwise_xor = ufunc (_bitwise_xor , name = "bitwise_xor" , nin = 2 , nout = 1 , identity = 0 , call = _bitwise_xor )
3705- logical_and = ufunc (_logical_and , name = "logical_and" , nin = 2 , nout = 1 , identity = True , call = _logical_and , reduce = _logical_and_reduce )
3706- logical_or = ufunc (_logical_or , name = "logical_or" , nin = 2 , nout = 1 , identity = False , call = _logical_or , reduce = _logical_or_reduce )
3707- logical_xor = ufunc (_logical_xor , name = "logical_xor" , nin = 2 , nout = 1 , identity = False , call = _logical_xor )
3708- negative = ufunc (_negative , name = "negative" , nin = 1 , nout = 1 , call = _negative )
3709- subtract = ufunc (_subtract , name = "subtract" , nin = 2 , nout = 1 , call = _subtract , at = _subtract_at )
0 commit comments