@@ -63,7 +63,7 @@ def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ig
6363def apply_where ( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR02
6464 cond : Array ,
6565 f1 : Callable [..., Array ],
66- f2 : Callable [..., Array ] | Array ,
66+ f2 : Callable [..., Array ] | Array , # optional positional argument
6767 / ,
6868 * args : Array ,
6969 fill_value : Array | int | float | complex | bool | None = None ,
@@ -119,37 +119,36 @@ def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR
119119 mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
120120 if is_array_api_obj (f2 ):
121121 args = (cast (Array , f2 ), * args )
122- if fill_value is not None :
123- raise TypeError (mutually_exc_msg )
124- f2_ : Callable [..., Array ] | None = None # type: ignore[no-any-explicit]
125- else :
126- if not callable (f2 ):
127- msg = "Third parameter must be either an Array or callable."
128- raise ValueError (msg )
129- f2_ = cast (Callable [..., Array ], f2 ) # type: ignore[no-any-explicit]
122+ f2 = None
130123 if fill_value is None :
131124 raise TypeError (mutually_exc_msg )
132125 if getattr (fill_value , "ndim" , 0 ) != 0 :
133126 msg = "`fill_value` must be a scalar."
134127 raise ValueError (msg )
135- del f2
128+ elif callable (f2 ):
129+ if fill_value is not None :
130+ raise TypeError (mutually_exc_msg )
131+ else :
132+ msg = "Third parameter must be either an Array or callable."
133+ raise ValueError (msg )
136134 if not args :
137135 msg = "Must give at least one input array."
138136 raise TypeError (msg )
137+ # End argument parsing
139138
140139 xp = array_namespace (cond , * args ) if xp is None else xp
141140
142141 if not is_dask_namespace (xp ):
143142 return _apply_where (
144- cond , f1 , f2_ , * args , fill_value = fill_value , dtype = None , xp = xp
143+ cond , f1 , f2 , * args , fill_value = fill_value , dtype = None , xp = xp
145144 )
146145
147146 # Dask-specific code from here onwards
148147 metas = [arg ._meta for arg in args ] # pylint: disable=protected-access
149148 meta_xp = array_namespace (cond ._meta , * metas ) # pylint: disable=protected-access
150149 # Determine output dtype
151- if f2_ is not None :
152- dtype = meta_xp .result_type (f1 (* metas ), f2_ (* metas ))
150+ if f2 is not None :
151+ dtype = meta_xp .result_type (f1 (* metas ), f2 (* metas ))
153152 elif is_dask_array (fill_value ):
154153 dtype = meta_xp .result_type (f1 (* metas ), cast (Array , fill_value )._meta ) # pylint: disable=protected-access
155154 else :
@@ -161,7 +160,7 @@ def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR
161160 partial (_apply_where , dtype = dtype , xp = meta_xp ),
162161 cond ,
163162 f1 ,
164- f2_ ,
163+ f2 ,
165164 * args ,
166165 fill_value = fill_value ,
167166 dtype = dtype ,
0 commit comments