Skip to content

Commit 873b8ea

Browse files
authored
✨ dtype typar default for ndarray (#555)
1 parent eee4a2c commit 873b8ea

File tree

2 files changed

+47
-36
lines changed

2 files changed

+47
-36
lines changed

src/numpy-stubs/@test/runtime/legacy/literal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
CF = frozenset({None, "C", "F"})
1919

2020
order_list: list[tuple[frozenset[str | None], Callable[..., Any]]] = [
21-
(KACF, partial(np.ndarray, 1)),
21+
(KACF, partial(np.ndarray.__call__, 1)),
2222
(KACF, AR.tobytes),
2323
(KACF, partial(AR.astype, int)),
2424
(KACF, AR.copy),

src/numpy-stubs/__init__.pyi

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ _RealT_co = TypeVar("_RealT_co", covariant=True)
619619
_ImagT_co = TypeVar("_ImagT_co", covariant=True)
620620

621621
_DTypeT = TypeVar("_DTypeT", bound=dtype)
622-
_DTypeT_co = TypeVar("_DTypeT_co", bound=dtype, covariant=True)
622+
_DTypeT_co = TypeVar("_DTypeT_co", bound=dtype, default=dtype, covariant=True)
623623
_FlexDTypeT = TypeVar("_FlexDTypeT", bound=dtype[flexible])
624624

625625
_ArrayT = TypeVar("_ArrayT", bound=_nt.Array)
@@ -628,6 +628,7 @@ _NumericArrayT = TypeVar("_NumericArrayT", bound=_nt.Array[number | timedelta64
628628

629629
_ShapeT = TypeVar("_ShapeT", bound=_nt.Shape)
630630
_ShapeT_co = TypeVar("_ShapeT_co", bound=_nt.Shape, covariant=True)
631+
_ShapeT0_co = TypeVar("_ShapeT0_co", bound=_nt.Shape, default=_nt.Shape, covariant=True)
631632
_Shape1NDT = TypeVar("_Shape1NDT", bound=_nt.Shape1N)
632633

633634
_ScalarT = TypeVar("_ScalarT", bound=generic)
@@ -998,9 +999,11 @@ _HasTypeWithReal: TypeAlias = _HasType[_HasReal[_T]]
998999
_HasTypeWithImag: TypeAlias = _HasType[_HasImag[_T]]
9991000

10001001
@type_check_only
1001-
class _HasDType(Protocol[_T_co]):
1002+
class _HasDType(Protocol[_T_co, _ShapeT0_co]):
10021003
@property
10031004
def dtype(self, /) -> _T_co: ...
1005+
@property
1006+
def shape(self, /) -> _ShapeT0_co: ...
10041007

10051008
_HasDTypeWithItem: TypeAlias = _HasDType[_HasTypeWithItem[_T]]
10061009
_HasDTypeWithReal: TypeAlias = _HasDType[_HasTypeWithReal[_T]]
@@ -2117,16 +2120,16 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21172120
def __ge__(self, other: _ArrayLikeObject_co, /) -> _nt.Array[bool_]: ...
21182121

21192122
#
2120-
def __abs__(self: _HasDType[_HasType[SupportsAbs[_ScalarT]]], /) -> _nt.Array[_ScalarT, _ShapeT_co]: ...
2123+
def __abs__(self: _HasDType[_HasType[SupportsAbs[_ScalarT]], _ShapeT], /) -> _nt.Array[_ScalarT, _ShapeT]: ...
21212124
def __neg__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019
21222125
def __pos__(self: _NumericArrayT, /) -> _NumericArrayT: ... # noqa: PYI019
21232126
def __invert__(self: _IntegralArrayT, /) -> _IntegralArrayT: ... # noqa: PYI019
21242127

21252128
#
21262129
@overload
2127-
def __add__(self: _nt.Array[_ScalarT], x: _nt.Casts[_ScalarT], /) -> _nt.Array[_ScalarT]: ... # type: ignore[overload-overlap]
2130+
def __add__(self: _nt.Array[_ScalarT], x: _nt.Casts[_ScalarT], /) -> _nt.Array[_ScalarT]: ...
21282131
@overload
2129-
def __add__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ... # type: ignore[overload-overlap]
2132+
def __add__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ...
21302133
@overload
21312134
def __add__(self: _nt.CastsWithBuiltin[_T, _ScalarT], x: _nt.SequenceND[_T], /) -> _nt.Array[_ScalarT]: ...
21322135
@overload
@@ -2138,21 +2141,21 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21382141
@overload
21392142
def __add__(self: _nt.Array[datetime64], x: _nt.CoTimeDelta_nd, /) -> _nt.Array[datetime64]: ...
21402143
@overload
2141-
def __add__(self: _nt.Array[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ...
2144+
def __add__(self: _nt.Array[_nt.co_timedelta, Any], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ...
21422145
@overload
2143-
def __add__(self: _nt.Array[object_], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match]
2146+
def __add__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ...
21442147
@overload
2145-
def __add__(self: _nt.Array[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ...
2148+
def __add__(self: _nt.Array[str_, Any], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ...
21462149
@overload
21472150
def __add__(self: _nt.StringArrayND[_T], x: _nt.ToString_nd[_T] | _nt.ToStr_nd, /) -> _nt.StringArrayND[_T]: ...
21482151
@overload
21492152
def __add__(self: _nt.Array[generic[_T]], x: _nt.Sequence1ND[_nt.op.CanRAdd[_T]], /) -> _nt.Array[Incomplete]: ...
21502153

21512154
#
21522155
@overload
2153-
def __radd__(self: _nt.Array[_ScalarT], x: _nt.Casts[_ScalarT], /) -> _nt.Array[_ScalarT]: ... # type: ignore[overload-overlap]
2156+
def __radd__(self: _nt.Array[_ScalarT], x: _nt.Casts[_ScalarT], /) -> _nt.Array[_ScalarT]: ...
21542157
@overload
2155-
def __radd__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ... # type: ignore[overload-overlap]
2158+
def __radd__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ...
21562159
@overload
21572160
def __radd__(self: _nt.CastsWithBuiltin[_T, _ScalarT], x: _nt.SequenceND[_T], /) -> _nt.Array[_ScalarT]: ...
21582161
@overload
@@ -2164,11 +2167,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
21642167
@overload
21652168
def __radd__(self: _nt.Array[datetime64], x: _nt.CoTimeDelta_nd, /) -> _nt.Array[datetime64]: ...
21662169
@overload
2167-
def __radd__(self: _nt.Array[_nt.co_timedelta], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ...
2170+
def __radd__(self: _nt.Array[_nt.co_timedelta, Any], x: _nt.ToDateTime_nd, /) -> _nt.Array[datetime64]: ...
21682171
@overload
2169-
def __radd__(self: _nt.Array[object_], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match]
2172+
def __radd__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ...
21702173
@overload
2171-
def __radd__(self: _nt.Array[str_], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ...
2174+
def __radd__(self: _nt.Array[str_, Any], x: _nt.ToString_nd[_T], /) -> _nt.StringArrayND[_T]: ...
21722175
@overload
21732176
def __radd__(self: _nt.StringArrayND[_T], x: _nt.ToString_nd[_T] | _nt.ToStr_nd, /) -> _nt.StringArrayND[_T]: ...
21742177
@overload
@@ -2268,7 +2271,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
22682271
@overload
22692272
def __mul__(self: _nt.Array[_CoNumberT], x: _nt.Casts[_CoNumberT], /) -> _nt.Array[_CoNumberT]: ...
22702273
@overload
2271-
def __mul__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ... # type: ignore[overload-overlap]
2274+
def __mul__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ...
22722275
@overload
22732276
def __mul__(self: _nt.CastsWithBuiltin[_T, _ScalarT], x: _nt.SequenceND[_T], /) -> _nt.Array[_ScalarT]: ...
22742277
@overload
@@ -2278,11 +2281,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
22782281
@overload
22792282
def __mul__(self: _nt.CastsWithComplex[_ScalarT], x: _PyComplexND, /) -> _nt.Array[_ScalarT]: ...
22802283
@overload
2281-
def __mul__(self: _nt.Array[timedelta64], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ...
2284+
def __mul__(self: _nt.Array[timedelta64, Any], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ...
22822285
@overload
2283-
def __mul__(self: _nt.Array[object_], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match]
2286+
def __mul__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ...
22842287
@overload
2285-
def __mul__(self: _nt.Array[integer], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ...
2288+
def __mul__(self: _nt.Array[integer, Any], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ...
22862289
@overload
22872290
def __mul__(self: _nt.StringArrayND[_T], x: _nt.ToInteger_nd, /) -> _nt.StringArrayND[_T]: ...
22882291
@overload
@@ -2292,7 +2295,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
22922295
@overload
22932296
def __rmul__(self: _nt.Array[_CoNumberT], x: _nt.Casts[_CoNumberT], /) -> _nt.Array[_CoNumberT]: ...
22942297
@overload
2295-
def __rmul__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ... # type: ignore[overload-overlap]
2298+
def __rmul__(self: _nt.Array[_SelfScalarT], x: _nt.CastsWith[_SelfScalarT, _ScalarT], /) -> _nt.Array[_ScalarT]: ...
22962299
@overload
22972300
def __rmul__(self: _nt.CastsWithBuiltin[_T, _ScalarT], x: _nt.SequenceND[_T], /) -> _nt.Array[_ScalarT]: ...
22982301
@overload
@@ -2302,11 +2305,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
23022305
@overload
23032306
def __rmul__(self: _nt.CastsWithComplex[_ScalarT], x: _PyComplexND, /) -> _nt.Array[_ScalarT]: ...
23042307
@overload
2305-
def __rmul__(self: _nt.Array[timedelta64], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ...
2308+
def __rmul__(self: _nt.Array[timedelta64, Any], x: _nt.ToFloating_nd, /) -> _nt.Array[timedelta64]: ...
23062309
@overload
2307-
def __rmul__(self: _nt.Array[object_], x: object, /) -> _nt.Array[object_]: ... # type: ignore[overload-cannot-match]
2310+
def __rmul__(self: _nt.Array[object_, Any], x: object, /) -> _nt.Array[object_]: ...
23082311
@overload
2309-
def __rmul__(self: _nt.Array[integer], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ...
2312+
def __rmul__(self: _nt.Array[integer, Any], x: _nt.ToString_nd, /) -> _nt.StringArrayND[_T]: ...
23102313
@overload
23112314
def __rmul__(self: _nt.StringArrayND[_T], x: _nt.ToInteger_nd, /) -> _nt.StringArrayND[_T]: ...
23122315
@overload
@@ -2666,47 +2669,55 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
26662669

26672670
#
26682671
@overload
2669-
def __divmod__(self: _nt.Array[bool_], x: _nt.ToBool_nd, /) -> _2Tuple[_nt.Array[int8]]: ...
2672+
def __divmod__(self: _nt.Array[bool_, _nt.Shape], x: _nt.ToBool_nd, /) -> _2Tuple[_nt.Array[int8]]: ...
26702673
@overload
26712674
def __divmod__(
2672-
self: _nt.Array[_RealNumberT], x: _nt.Casts[_RealNumberT] | _nt.ToBool_nd, /
2675+
self: _nt.Array[_RealNumberT, _nt.Shape], x: _nt.Casts[_RealNumberT, _nt.Shape] | _nt.ToBool_nd, /
26732676
) -> _2Tuple[_nt.Array[_RealNumberT]]: ...
26742677
@overload
26752678
def __divmod__(
2676-
self: _nt.Array[_CoFloatingT], x: _nt.CastsWith[_CoFloatingT, _RealScalarT], /
2679+
self: _nt.Array[_CoFloatingT, _nt.Shape], x: _nt.CastsWith[_CoFloatingT, _RealScalarT, _nt.Shape], /
26772680
) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
26782681
@overload
2679-
def __divmod__(self: _nt.CastsWithInt[_RealScalarT], x: _PyIntND, /) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
2682+
def __divmod__(
2683+
self: _nt.CastsWithInt[_RealScalarT, _nt.Shape], x: _PyIntND, /
2684+
) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
26802685
@overload
2681-
def __divmod__(self: _nt.CastsWithFloat[_RealScalarT], x: _PyFloatND, /) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
2686+
def __divmod__(
2687+
self: _nt.CastsWithFloat[_RealScalarT, _nt.Shape], x: _PyFloatND, /
2688+
) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
26822689
@overload
26832690
def __divmod__(
2684-
self: _nt.Array[timedelta64], x: _nt.ToTimeDelta_nd, /
2691+
self: _nt.Array[timedelta64, _nt.Shape], x: _nt.ToTimeDelta_nd, /
26852692
) -> tuple[_nt.Array[int64], _nt.Array[timedelta64]]: ...
26862693
@overload
2687-
def __divmod__(self: _nt.Array[object_], x: object, /) -> _2Tuple[_nt.Array[object_]]: ...
2694+
def __divmod__(self: _nt.Array[object_, _nt.Shape], x: object, /) -> _2Tuple[_nt.Array[object_]]: ...
26882695

26892696
#
26902697
@overload
2691-
def __rdivmod__(self: _nt.Array[bool_], x: _nt.ToBool_nd, /) -> _2Tuple[_nt.Array[int8]]: ...
2698+
def __rdivmod__(self: _nt.Array[bool_, _nt.Shape], x: _nt.ToBool_nd, /) -> _2Tuple[_nt.Array[int8]]: ...
26922699
@overload
26932700
def __rdivmod__(
2694-
self: _nt.Array[_RealNumberT], x: _nt.Casts[_RealNumberT] | _nt.ToBool_nd, /
2701+
self: _nt.Array[_RealNumberT, _nt.Shape], x: _nt.Casts[_RealNumberT, _nt.Shape] | _nt.ToBool_nd, /
26952702
) -> _2Tuple[_nt.Array[_RealNumberT]]: ...
26962703
@overload
26972704
def __rdivmod__(
2698-
self: _nt.Array[_CoFloatingT], x: _nt.CastsWith[_CoFloatingT, _RealScalarT], /
2705+
self: _nt.Array[_CoFloatingT, _nt.Shape], x: _nt.CastsWith[_CoFloatingT, _RealScalarT, _nt.Shape], /
26992706
) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
27002707
@overload
2701-
def __rdivmod__(self: _nt.CastsWithInt[_RealScalarT], x: _PyIntND, /) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
2708+
def __rdivmod__(
2709+
self: _nt.CastsWithInt[_RealScalarT, _nt.Shape], x: _PyIntND, /
2710+
) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
27022711
@overload
2703-
def __rdivmod__(self: _nt.CastsWithFloat[_RealScalarT], x: _PyFloatND, /) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
2712+
def __rdivmod__(
2713+
self: _nt.CastsWithFloat[_RealScalarT, _nt.Shape], x: _PyFloatND, /
2714+
) -> _2Tuple[_nt.Array[_RealScalarT]]: ...
27042715
@overload
27052716
def __rdivmod__(
2706-
self: _nt.Array[timedelta64], x: _nt.ToTimeDelta_nd, /
2717+
self: _nt.Array[timedelta64, _nt.Shape], x: _nt.ToTimeDelta_nd, /
27072718
) -> tuple[_nt.Array[int64], _nt.Array[timedelta64]]: ...
27082719
@overload
2709-
def __rdivmod__(self: _nt.Array[object_], x: object, /) -> _2Tuple[_nt.Array[object_]]: ...
2720+
def __rdivmod__(self: _nt.Array[object_, _nt.Shape], x: object, /) -> _2Tuple[_nt.Array[object_]]: ...
27102721

27112722
#
27122723
@overload

0 commit comments

Comments
 (0)