@@ -1432,11 +1432,37 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array:
14321432 """
14331433 return lax .ne (* promote_args ("not_equal" , x , y ))
14341434
1435- @ implements ( np . subtract , module = 'numpy' )
1435+
14361436@partial (jit , inline = True )
1437- def subtract (x : ArrayLike , y : ArrayLike , / ) -> Array :
1437+ def _subtract (x : ArrayLike , y : ArrayLike , / ) -> Array :
1438+ """Subtract two arrays element-wise.
1439+
1440+ JAX implementation of :obj:`numpy.subtract`. This is a universal function,
1441+ and supports the additional APIs described at :class:`jax.numpy.ufunc`.
1442+ This function provides the implementation of the ``-`` operator for
1443+ JAX arrays.
1444+
1445+ Args:
1446+ x, y: arrays to subtract. Must be broadcastable to a common shape.
1447+
1448+ Returns:
1449+ Array containing the result of the element-wise subtraction.
1450+
1451+ Examples:
1452+ Calling ``subtract`` explicitly:
1453+
1454+ >>> x = jnp.arange(4)
1455+ >>> jnp.subtract(x, 10)
1456+ Array([-10, -9, -8, -7], dtype=int32)
1457+
1458+ Calling ``subtract`` via the ``-`` operator:
1459+
1460+ >>> x - 10
1461+ Array([-10, -9, -8, -7], dtype=int32)
1462+ """
14381463 return lax .sub (* promote_args ("subtract" , x , y ))
14391464
1465+
14401466@implements (np .arctan2 , module = 'numpy' )
14411467@partial (jit , inline = True )
14421468def arctan2 (x1 : ArrayLike , x2 : ArrayLike , / ) -> Array :
@@ -3604,6 +3630,9 @@ def _add_at(a: Array, indices: Any, b: ArrayLike):
36043630 return a .at [indices ].add (b ).astype (bool )
36053631 return a .at [indices ].add (b )
36063632
3633+ def _subtract_at (a : Array , indices : Any , b : ArrayLike ):
3634+ return a .at [indices ].subtract (b )
3635+
36073636def _multiply_at (a : Array , indices : Any , b : ArrayLike ):
36083637 if a .dtype == bool :
36093638 a = a .astype ('int32' )
@@ -3628,3 +3657,4 @@ def _multiply_at(a: Array, indices: Any, b: ArrayLike):
36283657logical_or = ufunc (_logical_or , name = "logical_or" , nin = 2 , nout = 1 , identity = False , call = _logical_or , reduce = _logical_or_reduce )
36293658logical_xor = ufunc (_logical_xor , name = "logical_xor" , nin = 2 , nout = 1 , identity = False , call = _logical_xor )
36303659negative = ufunc (_negative , name = "negative" , nin = 1 , nout = 1 , call = _negative )
3660+ subtract = ufunc (_subtract , name = "subtract" , nin = 2 , nout = 1 , call = _subtract , at = _subtract_at )
0 commit comments