@@ -81,6 +81,20 @@ def _promote_integer_dtype(dtype: DTypeLike) -> DTypeLike:
8181 return dtypes .int_
8282 return dtype
8383
84+ def check_where (name : str , where : ArrayLike | None ) -> Array | None :
85+ if where is None :
86+ return where
87+ check_arraylike (name , where )
88+ where_arr = lax_internal .asarray (where )
89+ if where_arr .dtype != bool :
90+ # Deprecation added 2024-12-05
91+ deprecations .warn (
92+ 'jax-numpy-reduction-non-boolean-where' ,
93+ f"jnp.{ name } : where must be None or a boolean array; got dtype={ where_arr .dtype } ." ,
94+ stacklevel = 2 )
95+ return where_arr .astype (bool )
96+ return where_arr
97+
8498
8599ReductionOp = Callable [[Any , Any ], Any ]
86100
@@ -101,6 +115,7 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
101115 if out is not None :
102116 raise NotImplementedError (f"The 'out' argument to jnp.{ name } is not supported." )
103117 check_arraylike (name , a )
118+ where_ = check_where (name , where_ )
104119 dtypes .check_user_dtype_supported (dtype , name )
105120 axis = core .concrete_or_error (None , axis , f"axis argument to jnp.{ name } ()." )
106121
@@ -730,6 +745,8 @@ def _logsumexp(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
730745 if out is not None :
731746 raise NotImplementedError ("The 'out' argument to jnp.logaddexp.reduce is not supported." )
732747 dtypes .check_user_dtype_supported (dtype , "jnp.logaddexp.reduce" )
748+ check_arraylike ("logsumexp" , a )
749+ where = check_where ("logsumexp" , where )
733750 a_arr , = promote_dtypes_inexact (a )
734751 pos_dims , dims = _reduction_dims (a_arr , axis )
735752 amax = max (a_arr .real , axis = dims , keepdims = keepdims , where = where , initial = - np .inf )
@@ -748,6 +765,8 @@ def _logsumexp2(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
748765 if out is not None :
749766 raise NotImplementedError ("The 'out' argument to jnp.logaddexp2.reduce is not supported." )
750767 dtypes .check_user_dtype_supported (dtype , "jnp.logaddexp2.reduce" )
768+ check_arraylike ("logsumexp2" , a )
769+ where = check_where ("logsumexp2" , where )
751770 ln2 = float (np .log (2 ))
752771 if initial is not None :
753772 initial *= ln2
@@ -850,6 +869,7 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
850869 upcast_f16_for_computation : bool = True ,
851870 where : ArrayLike | None = None ) -> Array :
852871 check_arraylike ("mean" , a )
872+ where = check_where ("mean" , where )
853873 if out is not None :
854874 raise NotImplementedError ("The 'out' argument to jnp.mean is not supported." )
855875
@@ -1087,6 +1107,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
10871107 out : None = None , correction : int | float = 0 , keepdims : bool = False , * ,
10881108 where : ArrayLike | None = None ) -> Array :
10891109 check_arraylike ("var" , a )
1110+ where = check_where ("var" , where )
10901111 dtypes .check_user_dtype_supported (dtype , "var" )
10911112 if out is not None :
10921113 raise NotImplementedError ("The 'out' argument to jnp.var is not supported." )
@@ -1224,6 +1245,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
12241245 out : None = None , correction : int | float = 0 , keepdims : bool = False , * ,
12251246 where : ArrayLike | None = None ) -> Array :
12261247 check_arraylike ("std" , a )
1248+ where = check_where ("std" , where )
12271249 dtypes .check_user_dtype_supported (dtype , "std" )
12281250 if dtype is not None and not dtypes .issubdtype (dtype , np .inexact ):
12291251 raise ValueError (f"dtype argument to jnp.std must be inexact; got { dtype } " )
@@ -1330,13 +1352,15 @@ def count_nonzero(a: ArrayLike, axis: Axis = None,
13301352
13311353def _nan_reduction (a : ArrayLike , name : str , jnp_reduction : Callable [..., Array ],
13321354 init_val : ArrayLike , nan_if_all_nan : bool ,
1333- axis : Axis = None , keepdims : bool = False , ** kwargs ) -> Array :
1355+ axis : Axis = None , keepdims : bool = False , where : ArrayLike | None = None ,
1356+ ** kwargs ) -> Array :
13341357 check_arraylike (name , a )
1358+ where = check_where (name , where )
13351359 if not dtypes .issubdtype (dtypes .dtype (a ), np .inexact ):
1336- return jnp_reduction (a , axis = axis , keepdims = keepdims , ** kwargs )
1360+ return jnp_reduction (a , axis = axis , keepdims = keepdims , where = where , ** kwargs )
13371361
13381362 out = jnp_reduction (_where (lax_internal ._isnan (a ), _reduction_init_val (a , init_val ), a ),
1339- axis = axis , keepdims = keepdims , ** kwargs )
1363+ axis = axis , keepdims = keepdims , where = where , ** kwargs )
13401364 if nan_if_all_nan :
13411365 return _where (all (lax_internal ._isnan (a ), axis = axis , keepdims = keepdims ),
13421366 _lax_const (a , np .nan ), out )
@@ -1755,6 +1779,7 @@ def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out
17551779 Array([[nan, nan, nan, nan]], dtype=float32)
17561780 """
17571781 check_arraylike ("nanmean" , a )
1782+ where = check_where ("nanmean" , where )
17581783 if out is not None :
17591784 raise NotImplementedError ("The 'out' argument to jnp.nanmean is not supported." )
17601785 if dtypes .issubdtype (dtypes .dtype (a ), np .bool_ ) or dtypes .issubdtype (dtypes .dtype (a ), np .integer ):
@@ -1848,6 +1873,7 @@ def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
18481873 [4. ]], dtype=float32)
18491874 """
18501875 check_arraylike ("nanvar" , a )
1876+ where = check_where ("nanvar" , where )
18511877 dtypes .check_user_dtype_supported (dtype , "nanvar" )
18521878 if out is not None :
18531879 raise NotImplementedError ("The 'out' argument to jnp.nanvar is not supported." )
@@ -1943,6 +1969,7 @@ def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out:
19431969 Array([[0.5, 0.5, 0. , 0. ]], dtype=float32)
19441970 """
19451971 check_arraylike ("nanstd" , a )
1972+ where = check_where ("nanstd" , where )
19461973 dtypes .check_user_dtype_supported (dtype , "nanstd" )
19471974 if out is not None :
19481975 raise NotImplementedError ("The 'out' argument to jnp.nanstd is not supported." )
0 commit comments