Skip to content

Commit a533d1a

Browse files
committed
💥 use Rank* in the ndarray shape-typed methods
1 parent 5b9842f commit a533d1a

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

‎src/numpy-stubs/__init__.pyi

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ _NumericArrayT = TypeVar("_NumericArrayT", bound=NDArray[number | timedelta64 |
629629

630630
_ShapeT = TypeVar("_ShapeT", bound=_nt.Shape)
631631
_ShapeT_co = TypeVar("_ShapeT_co", bound=_nt.Shape, covariant=True)
632-
_ShapeT_1nd = TypeVar("_ShapeT_1nd", bound=_nt.Shape1N)
632+
_Shape1NDT = TypeVar("_Shape1NDT", bound=_nt.Shape1N)
633633

634634
_ScalarT = TypeVar("_ScalarT", bound=generic)
635635
_SelfScalarT = TypeVar("_SelfScalarT", bound=generic)
@@ -2946,19 +2946,23 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
29462946
@overload
29472947
def diagonal(
29482948
self: ndarray[_nt.Shape2, _DTypeT], /, offset: CanIndex = 0, axis1: CanIndex = 0, axis2: CanIndex = 1
2949-
) -> ndarray[_nt.Shape1, _DTypeT]: ...
2949+
) -> ndarray[_nt.Rank1, _DTypeT]: ...
29502950
@overload
29512951
def diagonal(
29522952
self: ndarray[_nt.Shape3, _DTypeT], /, offset: CanIndex = 0, axis1: CanIndex = 0, axis2: CanIndex = 1
2953-
) -> ndarray[_nt.Shape2, _DTypeT]: ...
2953+
) -> ndarray[_nt.Rank2, _DTypeT]: ...
29542954
@overload
29552955
def diagonal(
29562956
self: ndarray[_nt.Shape4, _DTypeT], /, offset: CanIndex = 0, axis1: CanIndex = 0, axis2: CanIndex = 1
2957-
) -> ndarray[_nt.Shape3, _DTypeT]: ...
2957+
) -> ndarray[_nt.Rank3, _DTypeT]: ...
2958+
@overload
2959+
def diagonal(
2960+
self: ndarray[_nt.Shape4N, _DTypeT], /, offset: CanIndex = 0, axis1: CanIndex = 0, axis2: CanIndex = 1
2961+
) -> ndarray[_nt.Rank3N, _DTypeT]: ...
29582962
@overload
29592963
def diagonal(
29602964
self: ndarray[_nt.Shape, _DTypeT], /, offset: CanIndex = 0, axis1: CanIndex = 0, axis2: CanIndex = 1
2961-
) -> ndarray[_nt.Shape, _DTypeT]: ...
2965+
) -> ndarray[Any, _DTypeT]: ...
29622966

29632967
#
29642968
@overload
@@ -3290,39 +3294,39 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
32903294
) -> ndarray[_AnyShapeT, _DTypeT]: ...
32913295

32923296
#
3293-
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[_nt.Shape1, _DTypeT_co]: ...
3294-
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[_nt.Shape1, _DTypeT_co]: ...
3297+
def flatten(self, /, order: _OrderKACF = "C") -> ndarray[_nt.Rank1, _DTypeT_co]: ...
3298+
def ravel(self, /, order: _OrderKACF = "C") -> ndarray[_nt.Rank1, _DTypeT_co]: ...
32953299

32963300
#
32973301
@overload # (None)
32983302
def reshape(self, shape: None, /, *, order: _OrderACF = "C", copy: py_bool | None = None) -> Self: ...
32993303
@overload # (empty_sequence)
33003304
def reshape( # type: ignore[overload-overlap] # mypy false positive
33013305
self,
3302-
shape: Sequence[Never],
3306+
shape: Sequence[Never] | _nt.Shape0,
33033307
/,
33043308
*,
33053309
order: _OrderACF = "C",
33063310
copy: py_bool | None = None,
3307-
) -> ndarray[_nt.Shape0, _DTypeT_co]: ...
3308-
@overload # (() | (int) | (int, int) | ....) # up to 8-d
3311+
) -> ndarray[_nt.Rank0, _DTypeT_co]: ...
3312+
@overload # (index)
33093313
def reshape(
33103314
self,
3311-
shape: _AnyShapeT,
3315+
size1: CanIndex | _nt.Shape1,
33123316
/,
33133317
*,
33143318
order: _OrderACF = "C",
33153319
copy: py_bool | None = None,
3316-
) -> ndarray[_AnyShapeT, _DTypeT_co]: ...
3317-
@overload # (index)
3320+
) -> ndarray[_nt.Rank1, _DTypeT_co]: ...
3321+
@overload # (index, index)
33183322
def reshape(
33193323
self,
3320-
size1: CanIndex,
3324+
size1: _nt.Shape2,
33213325
/,
33223326
*,
33233327
order: _OrderACF = "C",
33243328
copy: py_bool | None = None,
3325-
) -> ndarray[_nt.Shape1, _DTypeT_co]: ...
3329+
) -> ndarray[_nt.Rank2, _DTypeT_co]: ...
33263330
@overload # (index, index)
33273331
def reshape(
33283332
self,
@@ -3332,7 +3336,16 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
33323336
*,
33333337
order: _OrderACF = "C",
33343338
copy: py_bool | None = None,
3335-
) -> ndarray[_nt.Shape2, _DTypeT_co]: ...
3339+
) -> ndarray[_nt.Rank2, _DTypeT_co]: ...
3340+
@overload # (index, index, index)
3341+
def reshape(
3342+
self,
3343+
size1: _nt.Shape3,
3344+
/,
3345+
*,
3346+
order: _OrderACF = "C",
3347+
copy: py_bool | None = None,
3348+
) -> ndarray[_nt.Rank3, _DTypeT_co]: ...
33363349
@overload # (index, index, index)
33373350
def reshape(
33383351
self,
@@ -3343,7 +3356,16 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
33433356
*,
33443357
order: _OrderACF = "C",
33453358
copy: py_bool | None = None,
3346-
) -> ndarray[_nt.Shape3, _DTypeT_co]: ...
3359+
) -> ndarray[_nt.Rank3, _DTypeT_co]: ...
3360+
@overload # (index, index, index, index)
3361+
def reshape(
3362+
self,
3363+
size1: _nt.Shape4,
3364+
/,
3365+
*,
3366+
order: _OrderACF = "C",
3367+
copy: py_bool | None = None,
3368+
) -> ndarray[_nt.Rank4, _DTypeT_co]: ...
33473369
@overload # (index, index, index, index)
33483370
def reshape(
33493371
self,
@@ -3355,7 +3377,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
33553377
*,
33563378
order: _OrderACF = "C",
33573379
copy: py_bool | None = None,
3358-
) -> ndarray[_nt.Shape4, _DTypeT_co]: ...
3380+
) -> ndarray[_nt.Rank4, _DTypeT_co]: ...
33593381
@overload # (int, *(index, ...))
33603382
def reshape(
33613383
self,
@@ -3364,7 +3386,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
33643386
*shape: CanIndex,
33653387
order: _OrderACF = "C",
33663388
copy: py_bool | None = None,
3367-
) -> ndarray[_nt.Shape, _DTypeT_co]: ...
3389+
) -> ndarray[Incomplete, _DTypeT_co]: ...
33683390
@overload # (sequence[index])
33693391
def reshape(
33703392
self,
@@ -3373,7 +3395,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
33733395
*,
33743396
order: _OrderACF = "C",
33753397
copy: py_bool | None = None,
3376-
) -> ndarray[_nt.Shape, _DTypeT_co]: ...
3398+
) -> ndarray[Incomplete, _DTypeT_co]: ...
33773399

33783400
#
33793401
@overload
@@ -3405,23 +3427,23 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
34053427
@overload # (dtype: dtype[T])
34063428
def view(self, /, dtype: _DTypeLike[_ScalarT]) -> _nt.Array[_ScalarT, _ShapeT_co]: ...
34073429
@overload # (type: matrix)
3408-
def view(self, /, *, type: type[matrix[Any, Any]]) -> matrix[_nt.Shape2, _DTypeT_co]: ...
3430+
def view(self, /, *, type: type[matrix[Any, Any]]) -> matrix[_nt.Rank2, _DTypeT_co]: ...
34093431
@overload # (_: matrix)
3410-
def view(self, /, dtype: type[matrix[Any, Any]]) -> matrix[_nt.Shape2, _DTypeT_co]: ...
3432+
def view(self, /, dtype: type[matrix[Any, Any]]) -> matrix[_nt.Rank2, _DTypeT_co]: ...
34113433
@overload # (dtype: T, type: matrix)
34123434
def view(
34133435
self,
34143436
/,
34153437
dtype: _DTypeT | _HasDType[_DTypeT],
34163438
type: type[matrix[Any, Any]],
3417-
) -> matrix[_nt.Shape2, _DTypeT]: ...
3439+
) -> matrix[_nt.Rank2, _DTypeT]: ...
34183440
@overload # (dtype: dtype[T], type: matrix)
34193441
def view(
34203442
self,
34213443
/,
34223444
dtype: _DTypeLike[_ScalarT],
34233445
type: type[matrix[Any, Any]],
3424-
) -> matrix[_nt.Shape2, dtype[_ScalarT]]: ...
3446+
) -> _nt.Matrix[_ScalarT]: ...
34253447
@overload # (type: recarray)
34263448
def view(
34273449
self,
@@ -3560,9 +3582,9 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
35603582

35613583
#
35623584
@overload
3563-
def __array__(self, dtype: None = None, /) -> ndarray[_nt.Shape0, dtype[Self]]: ...
3585+
def __array__(self, dtype: None = None, /) -> _nt.Array0D[Self]: ...
35643586
@overload
3565-
def __array__(self, dtype: _DTypeT, /) -> ndarray[_nt.Shape0, _DTypeT]: ...
3587+
def __array__(self, dtype: _DTypeT, /) -> ndarray[_nt.Rank0, _DTypeT]: ...
35663588

35673589
#
35683590
@overload
@@ -3576,11 +3598,11 @@ class generic(_ArrayOrScalarCommon, Generic[_ItemT_co]):
35763598
@overload
35773599
def __array_wrap__(
35783600
self,
3579-
array: ndarray[_ShapeT_1nd, _DTypeT],
3601+
array: ndarray[_Shape1NDT, _DTypeT],
35803602
context: tuple[ufunc, tuple[object, ...], int] | None = None,
35813603
return_scalar: py_bool = True,
35823604
/,
3583-
) -> ndarray[_ShapeT_1nd, _DTypeT]: ...
3605+
) -> ndarray[_Shape1NDT, _DTypeT]: ...
35843606
@overload
35853607
def __array_wrap__(
35863608
self,

0 commit comments

Comments
 (0)