Skip to content

Commit 4a5ca2f

Browse files
Merge pull request jax-ml#24400 from jakevdp:subtract-ufunc
PiperOrigin-RevId: 688190106
2 parents 65307ab + 6467d03 commit 4a5ca2f

File tree

5 files changed

+39
-4
lines changed

5 files changed

+39
-4
lines changed

jax/_src/basearray.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ class _IndexUpdateRef:
284284
mode: str | None = None, fill_value: StaticScalar | None = None) -> Array: ...
285285
def add(self, values: Any, indices_are_sorted: bool = False,
286286
unique_indices: bool = False, mode: str | None = None) -> Array: ...
287+
def subtract(self, values: Any, *, indices_are_sorted: bool = False,
288+
unique_indices: bool = False, mode: str | None = None) -> Array: ...
287289
def mul(self, values: Any, indices_are_sorted: bool = False,
288290
unique_indices: bool = False, mode: str | None = None) -> Array: ...
289291
def multiply(self, values: Any, indices_are_sorted: bool = False,

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,7 +1787,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
17871787
slice1_tuple = tuple(slice1)
17881788
slice2_tuple = tuple(slice2)
17891789

1790-
op = ufuncs.not_equal if arr.dtype == np.bool_ else ufuncs.subtract
1790+
op = operator.not_equal if arr.dtype == np.bool_ else operator.sub
17911791
for _ in range(n):
17921792
arr = op(arr[slice1_tuple], arr[slice2_tuple])
17931793

jax/_src/numpy/ufuncs.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,11 +1432,37 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
14321432
"""
14331433
return lax.ne(*promote_args("not_equal", x, y))
14341434

1435-
@implements(np.subtract, module='numpy')
1435+
14361436
@partial(jit, inline=True)
1437-
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
1437+
def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array:
1438+
"""Subtract two arrays element-wise.
1439+
1440+
JAX implementation of :obj:`numpy.subtract`. This is a universal function,
1441+
and supports the additional APIs described at :class:`jax.numpy.ufunc`.
1442+
This function provides the implementation of the ``-`` operator for
1443+
JAX arrays.
1444+
1445+
Args:
1446+
x, y: arrays to subtract. Must be broadcastable to a common shape.
1447+
1448+
Returns:
1449+
Array containing the result of the element-wise subtraction.
1450+
1451+
Examples:
1452+
Calling ``subtract`` explicitly:
1453+
1454+
>>> x = jnp.arange(4)
1455+
>>> jnp.subtract(x, 10)
1456+
Array([-10, -9, -8, -7], dtype=int32)
1457+
1458+
Calling ``subtract`` via the ``-`` operator:
1459+
1460+
>>> x - 10
1461+
Array([-10, -9, -8, -7], dtype=int32)
1462+
"""
14381463
return lax.sub(*promote_args("subtract", x, y))
14391464

1465+
14401466
@implements(np.arctan2, module='numpy')
14411467
@partial(jit, inline=True)
14421468
def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
@@ -3604,6 +3630,9 @@ def _add_at(a: Array, indices: Any, b: ArrayLike):
36043630
return a.at[indices].add(b).astype(bool)
36053631
return a.at[indices].add(b)
36063632

3633+
def _subtract_at(a: Array, indices: Any, b: ArrayLike):
3634+
return a.at[indices].subtract(b)
3635+
36073636
def _multiply_at(a: Array, indices: Any, b: ArrayLike):
36083637
if a.dtype == bool:
36093638
a = a.astype('int32')
@@ -3628,3 +3657,4 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike):
36283657
logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce)
36293658
logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor)
36303659
negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative)
3660+
subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at)

jax/numpy/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ def stack(
829829
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
830830
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
831831
where: ArrayLike | None = ..., correction: int | float | None = ...) -> Array: ...
832-
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
832+
subtract: BinaryUfunc
833833
def sum(
834834
a: ArrayLike,
835835
axis: _Axis = ...,

tests/lax_numpy_ufuncs_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def test_binary_ufunc_reduce(self, name, shape, axis, dtype):
250250
jnp_fun = getattr(jnp, name)
251251
np_fun = getattr(np, name)
252252

253+
if jnp_fun.identity is None and axis is None and len(shape) > 1:
254+
self.skipTest("Multiple-axis reduction over non-reorderable ufunc.")
255+
253256
jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis)
254257
np_fun_reduce = partial(np_fun.reduce, axis=axis)
255258

0 commit comments

Comments
 (0)