Skip to content

Commit 9c61400

Browse files
authored
Merge pull request numpy#27683 from bersbersbers/bersbersbers-27638
TYP: Improve `np.sum` and `np.mean` return types with given `dtype`
2 parents ccb00b2 + 95c0592 commit 9c61400

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

numpy/_core/fromnumeric.pyi

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,16 @@ def clip(
586586
casting: _CastingKind = ...,
587587
) -> _ArrayType: ...
588588

589+
@overload
590+
def sum(
591+
a: _ArrayLike[_SCT],
592+
axis: None = ...,
593+
dtype: None = ...,
594+
out: None = ...,
595+
keepdims: Literal[False] = ...,
596+
initial: _NumberLike_co = ...,
597+
where: _ArrayLikeBool_co = ...,
598+
) -> _SCT: ...
589599
@overload
590600
def sum(
591601
a: _ArrayLike[_SCT],
@@ -595,8 +605,50 @@ def sum(
595605
keepdims: bool = ...,
596606
initial: _NumberLike_co = ...,
597607
where: _ArrayLikeBool_co = ...,
608+
) -> _SCT | NDArray[_SCT]: ...
609+
@overload
610+
def sum(
611+
a: ArrayLike,
612+
axis: None,
613+
dtype: _DTypeLike[_SCT],
614+
out: None = ...,
615+
keepdims: Literal[False] = ...,
616+
initial: _NumberLike_co = ...,
617+
where: _ArrayLikeBool_co = ...,
618+
) -> _SCT: ...
619+
@overload
620+
def sum(
621+
a: ArrayLike,
622+
axis: None = ...,
623+
*,
624+
dtype: _DTypeLike[_SCT],
625+
out: None = ...,
626+
keepdims: Literal[False] = ...,
627+
initial: _NumberLike_co = ...,
628+
where: _ArrayLikeBool_co = ...,
598629
) -> _SCT: ...
599630
@overload
631+
def sum(
632+
a: ArrayLike,
633+
axis: None | _ShapeLike,
634+
dtype: _DTypeLike[_SCT],
635+
out: None = ...,
636+
keepdims: bool = ...,
637+
initial: _NumberLike_co = ...,
638+
where: _ArrayLikeBool_co = ...,
639+
) -> _SCT | NDArray[_SCT]: ...
640+
@overload
641+
def sum(
642+
a: ArrayLike,
643+
axis: None | _ShapeLike = ...,
644+
*,
645+
dtype: _DTypeLike[_SCT],
646+
out: None = ...,
647+
keepdims: bool = ...,
648+
initial: _NumberLike_co = ...,
649+
where: _ArrayLikeBool_co = ...,
650+
) -> _SCT | NDArray[_SCT]: ...
651+
@overload
600652
def sum(
601653
a: ArrayLike,
602654
axis: None | _ShapeLike = ...,
@@ -1207,6 +1259,26 @@ def mean(
12071259
where: _ArrayLikeBool_co = ...,
12081260
) -> _SCT: ...
12091261
@overload
1262+
def mean(
1263+
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1264+
axis: None,
1265+
dtype: _DTypeLike[_SCT],
1266+
out: None = ...,
1267+
keepdims: bool = ...,
1268+
*,
1269+
where: _ArrayLikeBool_co = ...,
1270+
) -> _SCT | NDArray[_SCT]: ...
1271+
@overload
1272+
def mean(
1273+
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1274+
axis: None = ...,
1275+
*,
1276+
dtype: _DTypeLike[_SCT],
1277+
out: None = ...,
1278+
keepdims: bool = ...,
1279+
where: _ArrayLikeBool_co = ...,
1280+
) -> _SCT | NDArray[_SCT]: ...
1281+
@overload
12101282
def mean(
12111283
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
12121284
axis: None | _ShapeLike = ...,

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ assert_type(np.sum(AR_f4), np.float32)
169169
assert_type(np.sum(AR_b, axis=0), Any)
170170
assert_type(np.sum(AR_f4, axis=0), Any)
171171
assert_type(np.sum(AR_f4, out=AR_subclass), NDArraySubclass)
172+
assert_type(np.sum(AR_f4, dtype=np.float64), np.float64)
173+
assert_type(np.sum(AR_f4, None, np.float64), np.float64)
174+
assert_type(np.sum(AR_f4, dtype=np.float64, keepdims=False), np.float64)
175+
assert_type(np.sum(AR_f4, None, np.float64, keepdims=False), np.float64)
176+
assert_type(np.sum(AR_f4, dtype=np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64])
177+
assert_type(np.sum(AR_f4, None, np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64])
172178

173179
assert_type(np.all(b), np.bool)
174180
assert_type(np.all(f4), np.bool)
@@ -310,6 +316,12 @@ assert_type(np.mean(AR_f4, keepdims=True), Any)
310316
assert_type(np.mean(AR_f4, dtype=float), Any)
311317
assert_type(np.mean(AR_f4, dtype=np.float64), np.float64)
312318
assert_type(np.mean(AR_f4, out=AR_subclass), NDArraySubclass)
319+
assert_type(np.mean(AR_f4, dtype=np.float64), np.float64)
320+
assert_type(np.mean(AR_f4, None, np.float64), np.float64)
321+
assert_type(np.mean(AR_f4, dtype=np.float64, keepdims=False), np.float64)
322+
assert_type(np.mean(AR_f4, None, np.float64, keepdims=False), np.float64)
323+
assert_type(np.mean(AR_f4, dtype=np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64])
324+
assert_type(np.mean(AR_f4, None, np.float64, keepdims=True), np.float64 | npt.NDArray[np.float64])
313325

314326
assert_type(np.std(AR_b), np.floating[Any])
315327
assert_type(np.std(AR_i8), np.floating[Any])

0 commit comments

Comments
 (0)