Skip to content

Commit f73fa7a

Browse files
Merge pull request jax-ml#25290 from jakevdp:reduction-where
PiperOrigin-RevId: 703182008
2 parents a71f9a6 + 29a8cce commit f73fa7a

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

jax/_src/deprecations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,6 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
129129
register('jax-numpy-linalg-matrix_rank-tol')
130130
register('jax-numpy-linalg-pinv-rcond')
131131
register('jax-numpy-quantile-interpolation')
132+
register('jax-numpy-reduction-non-boolean-where')
132133
register('jax-numpy-trimzeros-not-1d-array')
133134
register('pallas-gpu-triton')

jax/_src/numpy/reductions.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
8181
return dtypes.int_
8282
return dtype
8383

84+
def check_where(name: str, where: ArrayLike | None) -> Array | None:
85+
if where is None:
86+
return where
87+
check_arraylike(name, where)
88+
where_arr = lax_internal.asarray(where)
89+
if where_arr.dtype != bool:
90+
# Deprecation added 2024-12-05
91+
deprecations.warn(
92+
'jax-numpy-reduction-non-boolean-where',
93+
f"jnp.{name}: where must be None or a boolean array; got dtype={where_arr.dtype}.",
94+
stacklevel=2)
95+
return where_arr.astype(bool)
96+
return where_arr
97+
8498

8599
ReductionOp = Callable[[Any, Any], Any]
86100

@@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
101115
if out is not None:
102116
raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
103117
check_arraylike(name, a)
118+
where_ = check_where(name, where_)
104119
dtypes.check_user_dtype_supported(dtype, name)
105120
axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")
106121

@@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
730745
if out is not None:
731746
raise NotImplementedError("The 'out' argument to jnp.logaddexp.reduce is not supported.")
732747
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp.reduce")
748+
check_arraylike("logsumexp", a)
749+
where = check_where("logsumexp", where)
733750
a_arr, = promote_dtypes_inexact(a)
734751
pos_dims, dims = _reduction_dims(a_arr, axis)
735752
amax = max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-np.inf)
@@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
748765
if out is not None:
749766
raise NotImplementedError("The 'out' argument to jnp.logaddexp2.reduce is not supported.")
750767
dtypes.check_user_dtype_supported(dtype, "jnp.logaddexp2.reduce")
768+
check_arraylike("logsumexp2", a)
769+
where = check_where("logsumexp2", where)
751770
ln2 = float(np.log(2))
752771
if initial is not None:
753772
initial *= ln2
@@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
850869
upcast_f16_for_computation: bool = True,
851870
where: ArrayLike | None = None) -> Array:
852871
check_arraylike("mean", a)
872+
where = check_where("mean", where)
853873
if out is not None:
854874
raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")
855875

@@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
10871107
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
10881108
where: ArrayLike | None = None) -> Array:
10891109
check_arraylike("var", a)
1110+
where = check_where("var", where)
10901111
dtypes.check_user_dtype_supported(dtype, "var")
10911112
if out is not None:
10921113
raise NotImplementedError("The 'out' argument to jnp.var is not supported.")
@@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
12241245
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
12251246
where: ArrayLike | None = None) -> Array:
12261247
check_arraylike("std", a)
1248+
where = check_where("std", where)
12271249
dtypes.check_user_dtype_supported(dtype, "std")
12281250
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
12291251
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
@@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
13301352

13311353
def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array],
13321354
init_val: ArrayLike, nan_if_all_nan: bool,
1333-
axis: Axis = None, keepdims: bool = False, **kwargs) -> Array:
1355+
axis: Axis = None, keepdims: bool = False, where: ArrayLike | None = None,
1356+
**kwargs) -> Array:
13341357
check_arraylike(name, a)
1358+
where = check_where(name, where)
13351359
if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
1336-
return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)
1360+
return jnp_reduction(a, axis=axis, keepdims=keepdims, where=where, **kwargs)
13371361

13381362
out = jnp_reduction(_where(lax_internal._isnan(a), _reduction_init_val(a, init_val), a),
1339-
axis=axis, keepdims=keepdims, **kwargs)
1363+
axis=axis, keepdims=keepdims, where=where, **kwargs)
13401364
if nan_if_all_nan:
13411365
return _where(all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
13421366
_lax_const(a, np.nan), out)
@@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
17551779
Array([[nan, nan, nan, nan]], dtype=float32)
17561780
"""
17571781
check_arraylike("nanmean", a)
1782+
where = check_where("nanmean", where)
17581783
if out is not None:
17591784
raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
17601785
if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer):
@@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
18481873
[4. ]], dtype=float32)
18491874
"""
18501875
check_arraylike("nanvar", a)
1876+
where = check_where("nanvar", where)
18511877
dtypes.check_user_dtype_supported(dtype, "nanvar")
18521878
if out is not None:
18531879
raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")
@@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
19431969
Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
19441970
"""
19451971
check_arraylike("nanstd", a)
1972+
where = check_where("nanstd", where)
19461973
dtypes.check_user_dtype_supported(dtype, "nanstd")
19471974
if out is not None:
19481975
raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")

tests/lax_numpy_reducers_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,34 @@ def np_fun(x):
448448
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, atol=tol, rtol=tol)
449449
self._CompileAndCheck(jnp_fun, args_maker)
450450

451+
@jtu.sample_product(rec=JAX_REDUCER_INITIAL_RECORDS)
452+
def testReducerWhereNonBooleanErrorInitial(self, rec):
453+
dtype = rec.dtypes[0]
454+
x = jnp.zeros((10,), dtype)
455+
where = jnp.ones(10, dtype=int)
456+
func = getattr(jnp, rec.name)
457+
def assert_warns_or_errors(msg):
458+
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
459+
return self.assertRaisesRegex(ValueError, msg)
460+
else:
461+
return self.assertWarnsRegex(DeprecationWarning, msg)
462+
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
463+
func(x, where=where, initial=jnp.array(0, dtype=dtype))
464+
465+
@jtu.sample_product(rec=JAX_REDUCER_WHERE_NO_INITIAL_RECORDS)
466+
def testReducerWhereNonBooleanErrorNoInitial(self, rec):
467+
dtype = rec.dtypes[0]
468+
x = jnp.zeros((10,), dtype)
469+
where = jnp.ones(10, dtype=int)
470+
func = getattr(jnp, rec.name)
471+
def assert_warns_or_errors(msg):
472+
if deprecations.is_accelerated("jax-numpy-reduction-non-boolean-where"):
473+
return self.assertRaisesRegex(ValueError, msg)
474+
else:
475+
return self.assertWarnsRegex(DeprecationWarning, msg)
476+
with assert_warns_or_errors(f"jnp.{rec.name}: where must be None or a boolean array"):
477+
func(x, where=where)
478+
451479
@parameterized.parameters(itertools.chain.from_iterable(
452480
jtu.sample_product_testcases(
453481
[dict(name=rec.name, rng_factory=rec.rng_factory, inexact=rec.inexact,

0 commit comments

Comments
 (0)