Skip to content

Commit 66971a2

Browse files
committed
Fix jnp.diff for boolean inputs
1 parent 16fca38 commit 66971a2

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1787,7 +1787,7 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1,
17871787
slice1_tuple = tuple(slice1)
17881788
slice2_tuple = tuple(slice2)
17891789

1790-
op = operator.not_equal if arr.dtype == np.bool_ else operator.sub
1790+
op = operator.ne if arr.dtype == np.bool_ else operator.sub
17911791
for _ in range(n):
17921792
arr = op(arr[slice1_tuple], arr[slice2_tuple])
17931793

tests/lax_numpy_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2993,6 +2993,12 @@ def np_fun(x, n=n, axis=axis, prepend=prepend, append=append):
29932993
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
29942994
self._CompileAndCheck(jnp_fun, args_maker)
29952995

2996+
def testDiffBool(self):
2997+
rng = jtu.rand_default(self.rng())
2998+
args_maker = lambda: [rng((10,), bool)]
2999+
self._CheckAgainstNumpy(np.diff, jnp.diff, args_maker, check_dtypes=False)
3000+
self._CompileAndCheck(jnp.diff, args_maker)
3001+
29963002
def testDiffPrepoendScalar(self):
29973003
# Regression test for https://github.com/jax-ml/jax/issues/19362
29983004
x = jnp.arange(10)

0 commit comments

Comments
 (0)