Skip to content

Commit e5147b2

Browse files
committed
tweak
1 parent 3e669a2 commit e5147b2

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def apply_where( # type: ignore[no-any-explicit,no-any-decorated] # numpydoc ig
6363
def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR02
6464
cond: Array,
6565
f1: Callable[..., Array],
66-
f2: Callable[..., Array] | Array,
66+
f2: Callable[..., Array] | Array, # optional positional argument
6767
/,
6868
*args: Array,
6969
fill_value: Array | int | float | complex | bool | None = None,
@@ -119,37 +119,36 @@ def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR
119119
mutually_exc_msg = "Exactly one of `fill_value` or `f2` must be given."
120120
if is_array_api_obj(f2):
121121
args = (cast(Array, f2), *args)
122-
if fill_value is not None:
123-
raise TypeError(mutually_exc_msg)
124-
f2_: Callable[..., Array] | None = None # type: ignore[no-any-explicit]
125-
else:
126-
if not callable(f2):
127-
msg = "Third parameter must be either an Array or callable."
128-
raise ValueError(msg)
129-
f2_ = cast(Callable[..., Array], f2) # type: ignore[no-any-explicit]
122+
f2 = None
130123
if fill_value is None:
131124
raise TypeError(mutually_exc_msg)
132125
if getattr(fill_value, "ndim", 0) != 0:
133126
msg = "`fill_value` must be a scalar."
134127
raise ValueError(msg)
135-
del f2
128+
elif callable(f2):
129+
if fill_value is not None:
130+
raise TypeError(mutually_exc_msg)
131+
else:
132+
msg = "Third parameter must be either an Array or callable."
133+
raise ValueError(msg)
136134
if not args:
137135
msg = "Must give at least one input array."
138136
raise TypeError(msg)
137+
# End argument parsing
139138

140139
xp = array_namespace(cond, *args) if xp is None else xp
141140

142141
if not is_dask_namespace(xp):
143142
return _apply_where(
144-
cond, f1, f2_, *args, fill_value=fill_value, dtype=None, xp=xp
143+
cond, f1, f2, *args, fill_value=fill_value, dtype=None, xp=xp
145144
)
146145

147146
# Dask-specific code from here onwards
148147
metas = [arg._meta for arg in args] # pylint: disable=protected-access
149148
meta_xp = array_namespace(cond._meta, *metas) # pylint: disable=protected-access
150149
# Determine output dtype
151-
if f2_ is not None:
152-
dtype = meta_xp.result_type(f1(*metas), f2_(*metas))
150+
if f2 is not None:
151+
dtype = meta_xp.result_type(f1(*metas), f2(*metas))
153152
elif is_dask_array(fill_value):
154153
dtype = meta_xp.result_type(f1(*metas), cast(Array, fill_value)._meta) # pylint: disable=protected-access
155154
else:
@@ -161,7 +160,7 @@ def apply_where( # type: ignore[no-any-explicit,misc] # numpydoc ignore=PR01,PR
161160
partial(_apply_where, dtype=dtype, xp=meta_xp),
162161
cond,
163162
f1,
164-
f2_,
163+
f2,
165164
*args,
166165
fill_value=fill_value,
167166
dtype=dtype,

0 commit comments

Comments
 (0)