Skip to content

Commit 54c298d

Browse files
committed
TYP: Improved numpy.frompyfunc type hints
Changed the ``frompyfunc`` signature and added overloads. It now returns a specialized ``ufunc`` type for integer ``nin`` and ``nout``.
1 parent 8420a7e commit 54c298d

File tree

3 files changed

+598
-10
lines changed

3 files changed

+598
-10
lines changed

numpy/_core/multiarray.pyi

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ from numpy._typing import (
7979
_FloatLike_co,
8080
_TD64Like_co,
8181
)
82+
from numpy._typing._ufunc import (
83+
_2PTuple,
84+
_PyFunc_Nin1_Nout1,
85+
_PyFunc_Nin2_Nout1,
86+
_PyFunc_Nin3P_Nout1,
87+
_PyFunc_Nin1P_Nout2P,
88+
)
8289

8390
_T_co = TypeVar("_T_co", covariant=True)
8491
_T_contra = TypeVar("_T_contra", contravariant=True)
@@ -89,8 +96,12 @@ _ArrayType_co = TypeVar(
8996
bound=ndarray[Any, Any],
9097
covariant=True,
9198
)
92-
_SizeType = TypeVar("_SizeType", bound=int)
99+
_ReturnType = TypeVar("_ReturnType")
100+
_IDType = TypeVar("_IDType")
101+
_Nin = TypeVar("_Nin", bound=int)
102+
_Nout = TypeVar("_Nout", bound=int)
93103

104+
_SizeType = TypeVar("_SizeType", bound=int)
94105
_1DArray: TypeAlias = ndarray[tuple[_SizeType], dtype[_SCT]]
95106

96107
# Valid time units
@@ -682,12 +693,77 @@ def fromstring(
682693
like: None | _SupportsArrayFunc = ...,
683694
) -> NDArray[Any]: ...
684695

696+
@overload
697+
def frompyfunc( # type: ignore[overload-overlap]
698+
func: Callable[[Any], _ReturnType], /,
699+
nin: L[1],
700+
nout: L[1],
701+
*,
702+
identity: None = ...,
703+
) -> _PyFunc_Nin1_Nout1[_ReturnType, None]: ...
704+
@overload
705+
def frompyfunc( # type: ignore[overload-overlap]
706+
func: Callable[[Any], _ReturnType], /,
707+
nin: L[1],
708+
nout: L[1],
709+
*,
710+
identity: _IDType,
711+
) -> _PyFunc_Nin1_Nout1[_ReturnType, _IDType]: ...
712+
@overload
713+
def frompyfunc( # type: ignore[overload-overlap]
714+
func: Callable[[Any, Any], _ReturnType], /,
715+
nin: L[2],
716+
nout: L[1],
717+
*,
718+
identity: None = ...,
719+
) -> _PyFunc_Nin2_Nout1[_ReturnType, None]: ...
720+
@overload
721+
def frompyfunc( # type: ignore[overload-overlap]
722+
func: Callable[[Any, Any], _ReturnType], /,
723+
nin: L[2],
724+
nout: L[1],
725+
*,
726+
identity: _IDType,
727+
) -> _PyFunc_Nin2_Nout1[_ReturnType, _IDType]: ...
728+
@overload
729+
def frompyfunc( # type: ignore[overload-overlap]
730+
func: Callable[..., _ReturnType], /,
731+
nin: _Nin,
732+
nout: L[1],
733+
*,
734+
identity: None = ...,
735+
) -> _PyFunc_Nin3P_Nout1[_ReturnType, None, _Nin]: ...
736+
@overload
737+
def frompyfunc( # type: ignore[overload-overlap]
738+
func: Callable[..., _ReturnType], /,
739+
nin: _Nin,
740+
nout: L[1],
741+
*,
742+
identity: _IDType,
743+
) -> _PyFunc_Nin3P_Nout1[_ReturnType, _IDType, _Nin]: ...
744+
@overload
745+
def frompyfunc(
746+
func: Callable[..., _2PTuple[_ReturnType]], /,
747+
nin: _Nin,
748+
nout: _Nout,
749+
*,
750+
identity: None = ...,
751+
) -> _PyFunc_Nin1P_Nout2P[_ReturnType, None, _Nin, _Nout]: ...
752+
@overload
753+
def frompyfunc(
754+
func: Callable[..., _2PTuple[_ReturnType]], /,
755+
nin: _Nin,
756+
nout: _Nout,
757+
*,
758+
identity: _IDType,
759+
) -> _PyFunc_Nin1P_Nout2P[_ReturnType, _IDType, _Nin, _Nout]: ...
760+
@overload
685761
def frompyfunc(
686762
func: Callable[..., Any], /,
687763
nin: SupportsIndex,
688764
nout: SupportsIndex,
689765
*,
690-
identity: Any = ...,
766+
identity: None | object = ...,
691767
) -> ufunc: ...
692768

693769
@overload

0 commit comments

Comments
 (0)