Skip to content

Commit 1b4bd2a

Browse files
authored
Merge pull request numpy#27177 from jorenham/typing/arange-shape-type
TYP: 1-d ``numpy.arange`` return shape-type
2 parents dd66ffe + 4ef2def commit 1b4bd2a

File tree

2 files changed

+28
-24
lines changed

2 files changed

+28
-24
lines changed

numpy/_core/multiarray.pyi

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from collections.abc import Sequence, Callable, Iterable
66
from typing import (
77
Literal as L,
88
Any,
9+
TypeAlias,
910
overload,
1011
TypeVar,
1112
SupportsIndex,
@@ -88,6 +89,9 @@ _ArrayType_co = TypeVar(
8889
bound=ndarray[Any, Any],
8990
covariant=True,
9091
)
92+
_SizeType = TypeVar("_SizeType", bound=int)
93+
94+
_1DArray: TypeAlias = ndarray[tuple[_SizeType], dtype[_SCT]]
9195

9296
# Valid time units
9397
_UnitKind = L[
@@ -769,7 +773,7 @@ def arange( # type: ignore[misc]
769773
dtype: None = ...,
770774
device: None | L["cpu"] = ...,
771775
like: None | _SupportsArrayFunc = ...,
772-
) -> NDArray[signedinteger[Any]]: ...
776+
) -> _1DArray[int, signedinteger[Any]]: ...
773777
@overload
774778
def arange( # type: ignore[misc]
775779
start: _IntLike_co,
@@ -779,15 +783,15 @@ def arange( # type: ignore[misc]
779783
*,
780784
device: None | L["cpu"] = ...,
781785
like: None | _SupportsArrayFunc = ...,
782-
) -> NDArray[signedinteger[Any]]: ...
786+
) -> _1DArray[int, signedinteger[Any]]: ...
783787
@overload
784788
def arange( # type: ignore[misc]
785789
stop: _FloatLike_co,
786790
/, *,
787791
dtype: None = ...,
788792
device: None | L["cpu"] = ...,
789793
like: None | _SupportsArrayFunc = ...,
790-
) -> NDArray[floating[Any]]: ...
794+
) -> _1DArray[int, floating[Any]]: ...
791795
@overload
792796
def arange( # type: ignore[misc]
793797
start: _FloatLike_co,
@@ -797,15 +801,15 @@ def arange( # type: ignore[misc]
797801
*,
798802
device: None | L["cpu"] = ...,
799803
like: None | _SupportsArrayFunc = ...,
800-
) -> NDArray[floating[Any]]: ...
804+
) -> _1DArray[int, floating[Any]]: ...
801805
@overload
802806
def arange(
803807
stop: _TD64Like_co,
804808
/, *,
805809
dtype: None = ...,
806810
device: None | L["cpu"] = ...,
807811
like: None | _SupportsArrayFunc = ...,
808-
) -> NDArray[timedelta64]: ...
812+
) -> _1DArray[int, timedelta64]: ...
809813
@overload
810814
def arange(
811815
start: _TD64Like_co,
@@ -815,7 +819,7 @@ def arange(
815819
*,
816820
device: None | L["cpu"] = ...,
817821
like: None | _SupportsArrayFunc = ...,
818-
) -> NDArray[timedelta64]: ...
822+
) -> _1DArray[int, timedelta64]: ...
819823
@overload
820824
def arange( # both start and stop must always be specified for datetime64
821825
start: datetime64,
@@ -825,15 +829,15 @@ def arange( # both start and stop must always be specified for datetime64
825829
*,
826830
device: None | L["cpu"] = ...,
827831
like: None | _SupportsArrayFunc = ...,
828-
) -> NDArray[datetime64]: ...
832+
) -> _1DArray[int, datetime64]: ...
829833
@overload
830834
def arange(
831835
stop: Any,
832836
/, *,
833837
dtype: _DTypeLike[_SCT],
834838
device: None | L["cpu"] = ...,
835839
like: None | _SupportsArrayFunc = ...,
836-
) -> NDArray[_SCT]: ...
840+
) -> _1DArray[int, _SCT]: ...
837841
@overload
838842
def arange(
839843
start: Any,
@@ -843,15 +847,15 @@ def arange(
843847
*,
844848
device: None | L["cpu"] = ...,
845849
like: None | _SupportsArrayFunc = ...,
846-
) -> NDArray[_SCT]: ...
850+
) -> _1DArray[int, _SCT]: ...
847851
@overload
848852
def arange(
849853
stop: Any, /,
850854
*,
851855
dtype: DTypeLike,
852856
device: None | L["cpu"] = ...,
853857
like: None | _SupportsArrayFunc = ...,
854-
) -> NDArray[Any]: ...
858+
) -> _1DArray[int, Any]: ...
855859
@overload
856860
def arange(
857861
start: Any,
@@ -861,7 +865,7 @@ def arange(
861865
*,
862866
device: None | L["cpu"] = ...,
863867
like: None | _SupportsArrayFunc = ...,
864-
) -> NDArray[Any]: ...
868+
) -> _1DArray[int, Any]: ...
865869

866870
def datetime_data(
867871
dtype: str | _DTypeLike[datetime64] | _DTypeLike[timedelta64], /,

numpy/typing/tests/data/reveal/array_constructors.pyi

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Any, TypeVar
2+
from typing import Any, Literal as L, TypeVar
33
from pathlib import Path
44
from collections import deque
55

@@ -105,18 +105,18 @@ assert_type(np.frombuffer(A), npt.NDArray[np.float64])
105105
assert_type(np.frombuffer(A, dtype=np.int64), npt.NDArray[np.int64])
106106
assert_type(np.frombuffer(A, dtype="c16"), npt.NDArray[Any])
107107

108-
assert_type(np.arange(False, True), npt.NDArray[np.signedinteger[Any]])
109-
assert_type(np.arange(10), npt.NDArray[np.signedinteger[Any]])
110-
assert_type(np.arange(0, 10, step=2), npt.NDArray[np.signedinteger[Any]])
111-
assert_type(np.arange(10.0), npt.NDArray[np.floating[Any]])
112-
assert_type(np.arange(start=0, stop=10.0), npt.NDArray[np.floating[Any]])
113-
assert_type(np.arange(np.timedelta64(0)), npt.NDArray[np.timedelta64])
114-
assert_type(np.arange(0, np.timedelta64(10)), npt.NDArray[np.timedelta64])
115-
assert_type(np.arange(np.datetime64("0"), np.datetime64("10")), npt.NDArray[np.datetime64])
116-
assert_type(np.arange(10, dtype=np.float64), npt.NDArray[np.float64])
117-
assert_type(np.arange(0, 10, step=2, dtype=np.int16), npt.NDArray[np.int16])
118-
assert_type(np.arange(10, dtype=int), npt.NDArray[Any])
119-
assert_type(np.arange(0, 10, dtype="f8"), npt.NDArray[Any])
108+
assert_type(np.arange(False, True), np.ndarray[tuple[int], np.dtype[np.signedinteger[Any]]])
109+
assert_type(np.arange(10), np.ndarray[tuple[int], np.dtype[np.signedinteger[Any]]])
110+
assert_type(np.arange(0, 10, step=2), np.ndarray[tuple[int], np.dtype[np.signedinteger[Any]]])
111+
assert_type(np.arange(10.0), np.ndarray[tuple[int], np.dtype[np.floating[Any]]])
112+
assert_type(np.arange(start=0, stop=10.0), np.ndarray[tuple[int], np.dtype[np.floating[Any]]])
113+
assert_type(np.arange(np.timedelta64(0)), np.ndarray[tuple[int], np.dtype[np.timedelta64]])
114+
assert_type(np.arange(0, np.timedelta64(10)), np.ndarray[tuple[int], np.dtype[np.timedelta64]])
115+
assert_type(np.arange(np.datetime64("0"), np.datetime64("10")), np.ndarray[tuple[int], np.dtype[np.datetime64]])
116+
assert_type(np.arange(10, dtype=np.float64), np.ndarray[tuple[int], np.dtype[np.float64]])
117+
assert_type(np.arange(0, 10, step=2, dtype=np.int16), np.ndarray[tuple[int], np.dtype[np.int16]])
118+
assert_type(np.arange(10, dtype=int), np.ndarray[tuple[int], np.dtype[Any]])
119+
assert_type(np.arange(0, 10, dtype="f8"), np.ndarray[tuple[int], np.dtype[Any]])
120120

121121
assert_type(np.require(A), npt.NDArray[np.float64])
122122
assert_type(np.require(B), SubClass[np.float64])

0 commit comments

Comments
 (0)