Skip to content

Commit fc8a569

Browse files
authored
Merge pull request numpy#27157 from guan404ming/mean-type-overload
TYP: add td64 overload for `mean`
2 parents 5a81fef + f628ce1 commit fc8a569

File tree

3 files changed

+16
-0
lines changed

3 files changed

+16
-0
lines changed

numpy/_core/fromnumeric.pyi

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from numpy import (
1111
float16,
1212
floating,
1313
complexfloating,
14+
timedelta64,
1415
object_,
1516
generic,
1617
_OrderKACF,
@@ -35,6 +36,7 @@ from numpy._typing import (
3536
_ArrayLikeFloat_co,
3637
_ArrayLikeComplex_co,
3738
_ArrayLikeObject_co,
39+
_ArrayLikeTD64_co,
3840
_IntLike_co,
3941
_BoolLike_co,
4042
_ComplexLike_co,
@@ -1062,6 +1064,16 @@ def mean(
10621064
where: _ArrayLikeBool_co = ...,
10631065
) -> complexfloating[Any, Any]: ...
10641066
@overload
1067+
def mean(
1068+
a: _ArrayLikeTD64_co,
1069+
axis: None = ...,
1070+
dtype: None = ...,
1071+
out: None = ...,
1072+
keepdims: Literal[False] = ...,
1073+
*,
1074+
where: _ArrayLikeBool_co = ...,
1075+
) -> timedelta64: ...
1076+
@overload
10651077
def mean(
10661078
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
10671079
axis: None | _ShapeLike = ...,

numpy/typing/tests/data/fail/fromnumeric.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import numpy.typing as npt
66
A = np.array(True, ndmin=2, dtype=bool)
77
A.setflags(write=False)
88
AR_U: npt.NDArray[np.str_]
9+
AR_M: npt.NDArray[np.datetime64]
910

1011
a = np.bool(True)
1112

@@ -147,6 +148,7 @@ np.mean(a, axis=1.0) # E: No overload variant
147148
np.mean(a, out=False) # E: No overload variant
148149
np.mean(a, keepdims=1.0) # E: No overload variant
149150
np.mean(AR_U) # E: incompatible type
151+
np.mean(AR_M) # E: incompatible type
150152

151153
np.std(a, axis=1.0) # E: No overload variant
152154
np.std(a, out=False) # E: No overload variant

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ AR_u8: npt.NDArray[np.uint64]
2121
AR_i8: npt.NDArray[np.int64]
2222
AR_O: npt.NDArray[np.object_]
2323
AR_subclass: NDArraySubclass
24+
AR_m: npt.NDArray[np.timedelta64]
2425

2526
b: np.bool
2627
f4: np.float32
@@ -294,6 +295,7 @@ assert_type(np.around(AR_f4, out=AR_subclass), NDArraySubclass)
294295
assert_type(np.mean(AR_b), np.floating[Any])
295296
assert_type(np.mean(AR_i8), np.floating[Any])
296297
assert_type(np.mean(AR_f4), np.floating[Any])
298+
assert_type(np.mean(AR_m), np.timedelta64)
297299
assert_type(np.mean(AR_c16), np.complexfloating[Any, Any])
298300
assert_type(np.mean(AR_O), Any)
299301
assert_type(np.mean(AR_f4, axis=0), Any)

0 commit comments

Comments
 (0)