diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 0acd9c855..215f14711 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1224,7 +1224,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def apply( self, - f: Callable[..., S1], + f: Callable[..., S1 | NAType], axis: AxisIndex = ..., raw: _bool = ..., result_type: None = ..., @@ -1248,7 +1248,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def apply( self, - f: Callable[..., S1], + f: Callable[..., S1 | NAType], axis: Axis = ..., raw: _bool = ..., args: Any = ..., @@ -1309,7 +1309,7 @@ class DataFrame(NDFrame, OpsMixin): @overload def apply( self, - f: Callable[..., S1], + f: Callable[..., S1 | NAType], raw: _bool = ..., result_type: None = ..., args: Any = ..., diff --git a/tests/test_frame.py b/tests/test_frame.py index e17cff54c..8ce0dc356 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -43,6 +43,7 @@ ) import xarray as xr +from pandas._libs.missing import NAType from pandas._typing import Scalar from tests import ( @@ -578,6 +579,9 @@ def test_types_apply() -> None: def returns_scalar(x: pd.Series) -> int: return 2 + def returns_scalar_na(x: pd.Series) -> int | NAType: + return 2 if (x < 5).all() else pd.NA + def returns_series(x: pd.Series) -> pd.Series: return x**2 @@ -604,6 +608,11 @@ def gethead(s: pd.Series, y: int) -> pd.Series: check( assert_type(df.apply(returns_scalar), "pd.Series[int]"), pd.Series, np.integer ) + check( + assert_type(df.apply(returns_scalar_na), "pd.Series[int]"), + pd.Series, + int, + ) check(assert_type(df.apply(returns_series), pd.DataFrame), pd.DataFrame) check(assert_type(df.apply(returns_listlike_of_3), pd.DataFrame), pd.DataFrame) check(assert_type(df.apply(returns_dict), pd.Series), pd.Series)