Skip to content

Commit dd66ffe

Browse files
authored
Merge pull request numpy#27179 from jorenham/typing/piecewise-paramspec
TYP: Improved ``numpy.piecewise`` type-hints
2 parents f2666ed + 7e7cb50 commit dd66ffe

File tree

3 files changed

+47
-13
lines changed

3 files changed

+47
-13
lines changed

numpy/lib/_function_base_impl.pyi

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from collections.abc import Sequence, Iterator, Callable, Iterable
22
from typing import (
3+
Concatenate,
34
Literal as L,
45
Any,
6+
ParamSpec,
57
TypeVar,
68
overload,
79
Protocol,
@@ -34,6 +36,7 @@ from numpy._typing import (
3436
_ScalarLike_co,
3537
_DTypeLike,
3638
_ArrayLike,
39+
_ArrayLikeBool_co,
3740
_ArrayLikeInt_co,
3841
_ArrayLikeFloat_co,
3942
_ArrayLikeComplex_co,
@@ -50,6 +53,8 @@ from numpy._core.multiarray import (
5053

5154
_T = TypeVar("_T")
5255
_T_co = TypeVar("_T_co", covariant=True)
56+
# The `{}ss` suffix refers to the Python 3.12 syntax: `**P`
57+
_Pss = ParamSpec("_Pss")
5358
_SCT = TypeVar("_SCT", bound=generic)
5459
_ArrayType = TypeVar("_ArrayType", bound=NDArray[Any])
5560

@@ -180,23 +185,29 @@ def asarray_chkfinite(
180185
order: _OrderKACF = ...,
181186
) -> NDArray[Any]: ...
182187

183-
# TODO: Use PEP 612 `ParamSpec` once mypy supports `Concatenate`
184-
# xref python/mypy#8645
185188
@overload
186189
def piecewise(
187190
x: _ArrayLike[_SCT],
188-
condlist: ArrayLike,
189-
funclist: Sequence[Any | Callable[..., Any]],
190-
*args: Any,
191-
**kw: Any,
191+
condlist: _ArrayLike[bool_] | Sequence[_ArrayLikeBool_co],
192+
funclist: Sequence[
193+
Callable[Concatenate[NDArray[_SCT], _Pss], NDArray[_SCT | Any]]
194+
| _SCT | object
195+
],
196+
/,
197+
*args: _Pss.args,
198+
**kw: _Pss.kwargs,
192199
) -> NDArray[_SCT]: ...
193200
@overload
194201
def piecewise(
195202
x: ArrayLike,
196-
condlist: ArrayLike,
197-
funclist: Sequence[Any | Callable[..., Any]],
198-
*args: Any,
199-
**kw: Any,
203+
condlist: _ArrayLike[bool_] | Sequence[_ArrayLikeBool_co],
204+
funclist: Sequence[
205+
Callable[Concatenate[NDArray[Any], _Pss], NDArray[Any]]
206+
| object
207+
],
208+
/,
209+
*args: _Pss.args,
210+
**kw: _Pss.kwargs,
200211
) -> NDArray[Any]: ...
201212

202213
def select(

numpy/typing/tests/data/fail/lib_function_base.pyi

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ AR_c16: npt.NDArray[np.complex128]
88
AR_m: npt.NDArray[np.timedelta64]
99
AR_M: npt.NDArray[np.datetime64]
1010
AR_O: npt.NDArray[np.object_]
11+
AR_b_list: list[npt.NDArray[np.bool]]
1112

12-
def func(a: int) -> None: ...
13+
def fn_none_i(a: None, /) -> npt.NDArray[Any]: ...
14+
def fn_ar_i(a: npt.NDArray[np.float64], posarg: int, /) -> npt.NDArray[Any]: ...
1315

1416
np.average(AR_m) # E: incompatible type
1517
np.select(1, [AR_f8]) # E: incompatible type
@@ -21,6 +23,15 @@ np.place(1, [True], 1.5) # E: incompatible type
2123
np.vectorize(1) # E: incompatible type
2224
np.place(AR_f8, slice(None), 5) # E: incompatible type
2325

26+
np.piecewise(AR_f8, True, [fn_ar_i], 42) # E: No overload variants
27+
# TODO: enable these once mypy actually supports ParamSpec (released in 2021)
28+
# NOTE: pyright correctly reports errors for these (`reportCallIssue`)
29+
# np.piecewise(AR_f8, AR_b_list, [fn_none_i]) # E: No overload variants
30+
# np.piecewise(AR_f8, AR_b_list, [fn_ar_i]) # E: No overload variant
31+
# np.piecewise(AR_f8, AR_b_list, [fn_ar_i], 3.14) # E: No overload variant
32+
# np.piecewise(AR_f8, AR_b_list, [fn_ar_i], 42, None) # E: No overload variant
33+
# np.piecewise(AR_f8, AR_b_list, [fn_ar_i], 42, _=None) # E: No overload variant
34+
2435
np.interp(AR_f8, AR_c16, AR_f8) # E: incompatible type
2536
np.interp(AR_c16, AR_f8, AR_f8) # E: incompatible type
2637
np.interp(AR_f8, AR_f8, AR_f8, period=AR_c16) # E: No overload variant

numpy/typing/tests/data/reveal/lib_function_base.pyi

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,16 @@ AR_b: npt.NDArray[np.bool]
2424
AR_U: npt.NDArray[np.str_]
2525
CHAR_AR_U: np.char.chararray[Any, np.dtype[np.str_]]
2626

27-
def func(*args: Any, **kwargs: Any) -> Any: ...
27+
AR_b_list: list[npt.NDArray[np.bool]]
28+
29+
def func(
30+
a: npt.NDArray[Any],
31+
posarg: bool = ...,
32+
/,
33+
arg: int = ...,
34+
*,
35+
kwarg: str = ...,
36+
) -> npt.NDArray[Any]: ...
2837

2938
assert_type(vectorized_func.pyfunc, Callable[..., Any])
3039
assert_type(vectorized_func.cache, bool)
@@ -65,7 +74,10 @@ assert_type(np.asarray_chkfinite(AR_f8, dtype=np.float64), npt.NDArray[np.float6
6574
assert_type(np.asarray_chkfinite(AR_f8, dtype=float), npt.NDArray[Any])
6675

6776
assert_type(np.piecewise(AR_f8, AR_b, [func]), npt.NDArray[np.float64])
68-
assert_type(np.piecewise(AR_LIKE_f8, AR_b, [func]), npt.NDArray[Any])
77+
assert_type(np.piecewise(AR_f8, AR_b_list, [func]), npt.NDArray[np.float64])
78+
assert_type(np.piecewise(AR_f8, AR_b_list, [func], True, -1, kwarg=''), npt.NDArray[np.float64])
79+
assert_type(np.piecewise(AR_f8, AR_b_list, [func], True, arg=-1, kwarg=''), npt.NDArray[np.float64])
80+
assert_type(np.piecewise(AR_LIKE_f8, AR_b_list, [func]), npt.NDArray[Any])
6981

7082
assert_type(np.select([AR_f8], [AR_f8]), npt.NDArray[Any])
7183

0 commit comments

Comments
 (0)