Commit 72276dc

ENH: Support skipna parameter in GroupBy prod, var, std and sem methods
1 parent 42bf375 commit 72276dc

File tree: 8 files changed, +141 −20 lines changed


doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 1 deletion
@@ -58,9 +58,9 @@ Other enhancements
 - :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
 - :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
 - :meth:`DataFrame.plot.scatter` argument ``c`` now accepts a column of strings, where rows with the same string are colored identically (:issue:`16827` and :issue:`16485`)
+- :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` methods ``sum``, ``mean``, ``prod``, ``std``, ``var`` and ``sem`` now accept ``skipna`` parameter (:issue:`15675`)
 - :class:`Rolling` and :class:`Expanding` now support aggregations ``first`` and ``last`` (:issue:`33155`)
 - :func:`read_parquet` accepts ``to_pandas_kwargs`` which are forwarded to :meth:`pyarrow.Table.to_pandas` which enables passing additional keywords to customize the conversion to pandas, such as ``maps_as_pydicts`` to read the Parquet map data type as python dictionaries (:issue:`56842`)
-- :meth:`.DataFrameGroupBy.mean`, :meth:`.DataFrameGroupBy.sum`, :meth:`.SeriesGroupBy.mean` and :meth:`.SeriesGroupBy.sum` now accept ``skipna`` parameter (:issue:`15675`)
 - :meth:`.DataFrameGroupBy.transform`, :meth:`.SeriesGroupBy.transform`, :meth:`.DataFrameGroupBy.agg`, :meth:`.SeriesGroupBy.agg`, :meth:`.SeriesGroupBy.apply`, :meth:`.DataFrameGroupBy.apply` now support ``kurt`` (:issue:`40139`)
 - :meth:`DataFrameGroupBy.transform`, :meth:`SeriesGroupBy.transform`, :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, :meth:`RollingGroupby.apply`, :meth:`ExpandingGroupby.apply`, :meth:`Rolling.apply`, :meth:`Expanding.apply`, :meth:`DataFrame.apply` with ``engine="numba"`` now supports positional arguments passed as kwargs (:issue:`58995`)
 - :meth:`Rolling.agg`, :meth:`Expanding.agg` and :meth:`ExponentialMovingWindow.agg` now accept :class:`NamedAgg` aggregations through ``**kwargs`` (:issue:`28333`)
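
For a quick sense of the change, here is an illustrative usage sketch (the frame and values below are made up for demonstration; only the ``skipna`` parameter itself comes from this commit):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"key": ["a", "a", "b", "b"], "val": [1.0, np.nan, 2.0, 3.0]})
    gb = df.groupby("key")["val"]

    # Default behaviour is unchanged: NA values are skipped.
    gb.prod()              # expected: a -> 1.0, b -> 6.0

    # With skipna=False, any NA in a group propagates to that group's result.
    gb.prod(skipna=False)  # expected: a -> NaN, b -> 6.0
    gb.std(skipna=False)   # expected: a -> NaN
    gb.sem(skipna=False)   # expected: a -> NaN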

pandas/_libs/groupby.pyi

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,7 @@ def group_prod(
     mask: np.ndarray | None,
     result_mask: np.ndarray | None = ...,
     min_count: int = ...,
+    skipna: bool = ...,
 ) -> None: ...
 def group_var(
     out: np.ndarray,  # floating[:, ::1]
@@ -88,6 +89,7 @@ def group_var(
     result_mask: np.ndarray | None = ...,
     is_datetimelike: bool = ...,
     name: str = ...,
+    skipna: bool = ...,
 ) -> None: ...
 def group_skew(
     out: np.ndarray,  # float64_t[:, ::1]

pandas/_libs/groupby.pyx

Lines changed: 38 additions & 1 deletion
@@ -806,13 +806,14 @@ def group_prod(
     const uint8_t[:, ::1] mask,
     uint8_t[:, ::1] result_mask=None,
     Py_ssize_t min_count=0,
+    bint skipna=True,
 ) -> None:
     """
     Only aggregates on axis=0
     """
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
-        int64float_t val
+        int64float_t val, nan_val
         int64float_t[:, ::1] prodx
         int64_t[:, ::1] nobs
         Py_ssize_t len_values = len(values), len_labels = len(labels)
@@ -825,6 +826,13 @@
     prodx = np.ones((<object>out).shape, dtype=(<object>out).base.dtype)

     N, K = (<object>values).shape
+    if uses_mask:
+        nan_val = 0
+    elif int64float_t is int64_t or int64float_t is uint64_t:
+        # This has no effect as int64 can't be nan. Setting to 0 to avoid type error
+        nan_val = 0
+    else:
+        nan_val = NAN

     with nogil:
         for i in range(N):
@@ -836,6 +844,13 @@
             for j in range(K):
                 val = values[i, j]

+                if not skipna and (
+                    (uses_mask and result_mask[lab, j]) or
+                    _treat_as_na(prodx[lab, j], False)
+                ):
+                    # If prod is already NA, no need to update it
+                    continue
+
                 if uses_mask:
                     isna_entry = mask[i, j]
                 else:
@@ -844,6 +859,11 @@
                 if not isna_entry:
                     nobs[lab, j] += 1
                     prodx[lab, j] *= val
+                elif not skipna:
+                    if uses_mask:
+                        result_mask[lab, j] = True
+                    else:
+                        prodx[lab, j] = nan_val

     _check_below_mincount(
         out, uses_mask, result_mask, ncounts, K, nobs, min_count, prodx
@@ -864,6 +884,7 @@ def group_var(
     uint8_t[:, ::1] result_mask=None,
     bint is_datetimelike=False,
     str name="var",
+    bint skipna=True,
 ) -> None:
     cdef:
         Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
@@ -898,6 +919,16 @@
             for j in range(K):
                 val = values[i, j]

+                if not skipna and (
+                    (uses_mask and result_mask[lab, j]) or
+                    (is_datetimelike and out[lab, j] == NPY_NAT) or
+                    _treat_as_na(out[lab, j], False)
+                ):
+                    # If aggregate is already NA, don't add to it. This is important for
+                    # datetimelike because adding a value to NPY_NAT may not result
+                    # in a NPY_NAT
+                    continue
+
                 if uses_mask:
                     isna_entry = mask[i, j]
                 elif is_datetimelike:
@@ -913,6 +944,12 @@
                     oldmean = mean[lab, j]
                     mean[lab, j] += (val - oldmean) / nobs[lab, j]
                     out[lab, j] += (val - mean[lab, j]) * (val - oldmean)
+                elif not skipna:
+                    nobs[lab, j] = 0
+                    if uses_mask:
+                        result_mask[lab, j] = True
+                    else:
+                        out[lab, j] = NAN

     for i in range(ncounts):
         for j in range(K):
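
To make the control flow above easier to follow, here is a rough pure-Python sketch of what the new ``skipna`` handling does, reduced to a 1-D product (names and structure are illustrative, not the actual Cython internals):

    import numpy as np

    def group_prod_1d(values, labels, ngroups, skipna=True):
        # One running product and observation count per group.
        prodx = np.ones(ngroups)
        nobs = np.zeros(ngroups, dtype=np.int64)
        for val, lab in zip(values, labels):
            if lab < 0:
                continue
            # With skipna=False, a group that is already NA stays NA.
            if not skipna and np.isnan(prodx[lab]):
                continue
            if not np.isnan(val):
                nobs[lab] += 1
                prodx[lab] *= val
            elif not skipna:
                # First NA seen with skipna=False poisons the group's result.
                prodx[lab] = np.nan
        return prodx, nobs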

pandas/core/_numba/kernels/var_.py

Lines changed: 10 additions & 1 deletion
@@ -176,6 +176,7 @@ def grouped_var(
     ngroups: int,
     min_periods: int,
     ddof: int = 1,
+    skipna: bool = True,
 ) -> tuple[np.ndarray, list[int]]:
     N = len(labels)

@@ -190,7 +191,15 @@
         lab = labels[i]
         val = values[i]

-        if lab < 0:
+        if lab < 0 or np.isnan(output[lab]):
+            continue
+
+        if not skipna and np.isnan(val):
+            output[lab] = np.nan
+            nobs_arr[lab] += 1
+            comp_arr[lab] = np.nan
+            consecutive_counts[lab] = 1
+            prev_vals[lab] = np.nan
             continue

         mean_x = means[lab]
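
The numba kernel mirrors the Cython behaviour, so the two engines should agree. A sanity check along these lines is one way to see that (assuming the optional numba dependency is installed; the data is made up):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(
        {"key": ["a", "a", "a", "b", "b"], "val": [1.0, np.nan, 3.0, 2.0, 4.0]}
    )
    gb = df.groupby("key")["val"]

    cython_res = gb.var(skipna=False)
    numba_res = gb.var(skipna=False, engine="numba")  # expected to match the Cython path
    pd.testing.assert_series_equal(cython_res, numba_res)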

pandas/core/groupby/groupby.py

Lines changed: 45 additions & 6 deletions
@@ -2349,6 +2349,7 @@ def std(
         engine: Literal["cython", "numba"] | None = None,
         engine_kwargs: dict[str, bool] | None = None,
         numeric_only: bool = False,
+        skipna: bool = True,
     ):
         """
         Compute standard deviation of groups, excluding missing values.
@@ -2387,6 +2388,12 @@

                 numeric_only now defaults to ``False``.

+        skipna : bool, default True
+            Exclude NA/null values. If an entire row/column is NA, the result
+            will be NA.
+
+            .. versionadded:: 3.0.0
+
         Returns
         -------
         Series or DataFrame
@@ -2441,14 +2448,16 @@
                     engine_kwargs,
                     min_periods=0,
                     ddof=ddof,
+                    skipna=skipna,
                 )
             )
         else:
             return self._cython_agg_general(
                 "std",
-                alt=lambda x: Series(x, copy=False).std(ddof=ddof),
+                alt=lambda x: Series(x, copy=False).std(ddof=ddof, skipna=skipna),
                 numeric_only=numeric_only,
                 ddof=ddof,
+                skipna=skipna,
             )

     @final
@@ -2460,6 +2469,7 @@ def var(
         engine: Literal["cython", "numba"] | None = None,
         engine_kwargs: dict[str, bool] | None = None,
         numeric_only: bool = False,
+        skipna: bool = True,
     ):
         """
         Compute variance of groups, excluding missing values.
@@ -2497,6 +2507,12 @@

                 numeric_only now defaults to ``False``.

+        skipna : bool, default True
+            Exclude NA/null values. If an entire row/column is NA, the result
+            will be NA.
+
+            .. versionadded:: 3.0.0
+
         Returns
         -------
         Series or DataFrame
@@ -2550,13 +2566,15 @@
                 engine_kwargs,
                 min_periods=0,
                 ddof=ddof,
+                skipna=skipna,
             )
         else:
             return self._cython_agg_general(
                 "var",
-                alt=lambda x: Series(x, copy=False).var(ddof=ddof),
+                alt=lambda x: Series(x, copy=False).var(ddof=ddof, skipna=skipna),
                 numeric_only=numeric_only,
                 ddof=ddof,
+                skipna=skipna,
             )

     @final
@@ -2686,7 +2704,9 @@ def _value_counts(
             return result.__finalize__(self.obj, method="value_counts")

     @final
-    def sem(self, ddof: int = 1, numeric_only: bool = False) -> NDFrameT:
+    def sem(
+        self, ddof: int = 1, numeric_only: bool = False, skipna: bool = True
+    ) -> NDFrameT:
         """
         Compute standard error of the mean of groups, excluding missing values.

@@ -2706,6 +2726,12 @@ def sem(self, ddof: int = 1, numeric_only: bool = False) -> NDFrameT:

                 numeric_only now defaults to ``False``.

+        skipna : bool, default True
+            Exclude NA/null values. If an entire row/column is NA, the result
+            will be NA.
+
+            .. versionadded:: 3.0.0
+
         Returns
         -------
         Series or DataFrame
@@ -2780,9 +2806,10 @@ def sem(self, ddof: int = 1, numeric_only: bool = False) -> NDFrameT:
             )
         return self._cython_agg_general(
             "sem",
-            alt=lambda x: Series(x, copy=False).sem(ddof=ddof),
+            alt=lambda x: Series(x, copy=False).sem(ddof=ddof, skipna=skipna),
             numeric_only=numeric_only,
             ddof=ddof,
+            skipna=skipna,
         )

     @final
@@ -2959,7 +2986,9 @@ def sum(
         return result

     @final
-    def prod(self, numeric_only: bool = False, min_count: int = 0) -> NDFrameT:
+    def prod(
+        self, numeric_only: bool = False, min_count: int = 0, skipna: bool = True
+    ) -> NDFrameT:
         """
         Compute prod of group values.

@@ -2976,6 +3005,12 @@ def prod(self, numeric_only: bool = False, min_count: int = 0) -> NDFrameT:
             The required number of valid values to perform the operation. If fewer
             than ``min_count`` non-NA values are present the result will be NA.

+        skipna : bool, default True
+            Exclude NA/null values. If an entire row/column is NA, the result
+            will be NA.
+
+            .. versionadded:: 3.0.0
+
         Returns
         -------
         Series or DataFrame
@@ -3024,7 +3059,11 @@ def prod(self, numeric_only: bool = False, min_count: int = 0) -> NDFrameT:
         2  30  72
         """
         return self._agg_general(
-            numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
+            numeric_only=numeric_only,
+            min_count=min_count,
+            skipna=skipna,
+            alias="prod",
+            npfunc=np.prod,
         )

     @final
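
As the new docstrings describe, when ``skipna=False`` a group containing any NA yields a missing result; with nullable (masked) dtypes the missing value is ``pd.NA`` rather than ``NaN``. An illustrative sketch (data made up; dtype behaviour matches the new test table below):

    import pandas as pd

    df = pd.DataFrame(
        {"key": ["a", "a", "b"], "val": pd.array([1, None, 3], dtype="Int64")}
    )
    df.groupby("key")["val"].prod(skipna=False)
    # expected roughly:
    # key
    # a    <NA>
    # b       3
    # Name: val, dtype: Int64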

pandas/tests/groupby/aggregate/test_numba.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def test_multifunc_numba_vs_cython_frame(agg_kwargs):
     tm.assert_frame_equal(result, expected)


-@pytest.mark.parametrize("func", ["sum", "mean"])
+@pytest.mark.parametrize("func", ["sum", "mean", "var", "std"])
 def test_multifunc_numba_vs_cython_frame_noskipna(func):
     pytest.importorskip("numba")
     data = DataFrame(

pandas/tests/groupby/test_api.py

Lines changed: 8 additions & 10 deletions
@@ -176,14 +176,13 @@ def test_frame_consistency(groupby_func):
     elif groupby_func in ("max", "min"):
         exclude_expected = {"axis", "kwargs", "skipna"}
         exclude_result = {"min_count", "engine", "engine_kwargs"}
-    elif groupby_func in ("sum", "mean"):
+    elif groupby_func in ("sum", "mean", "std", "var"):
         exclude_expected = {"axis", "kwargs"}
         exclude_result = {"engine", "engine_kwargs"}
-    elif groupby_func in ("std", "var"):
-        exclude_expected = {"axis", "kwargs", "skipna"}
-        exclude_result = {"engine", "engine_kwargs"}
-    elif groupby_func in ("median", "prod", "sem"):
+    elif groupby_func in ("median"):
         exclude_expected = {"axis", "kwargs", "skipna"}
+    elif groupby_func in ("prod", "sem"):
+        exclude_expected = {"axis", "kwargs"}
     elif groupby_func in ("bfill", "ffill"):
         exclude_expected = {"inplace", "axis", "limit_area"}
     elif groupby_func in ("cummax", "cummin"):
@@ -237,14 +236,13 @@ def test_series_consistency(request, groupby_func):
     elif groupby_func in ("max", "min"):
         exclude_expected = {"axis", "kwargs", "skipna"}
         exclude_result = {"min_count", "engine", "engine_kwargs"}
-    elif groupby_func in ("sum", "mean"):
+    elif groupby_func in ("sum", "mean", "std", "var"):
         exclude_expected = {"axis", "kwargs"}
         exclude_result = {"engine", "engine_kwargs"}
-    elif groupby_func in ("std", "var"):
-        exclude_expected = {"axis", "kwargs", "skipna"}
-        exclude_result = {"engine", "engine_kwargs"}
-    elif groupby_func in ("median", "prod", "sem"):
+    elif groupby_func in ("median"):
         exclude_expected = {"axis", "kwargs", "skipna"}
+    elif groupby_func in ("prod", "sem"):
+        exclude_expected = {"axis", "kwargs"}
     elif groupby_func in ("bfill", "ffill"):
         exclude_expected = {"inplace", "axis", "limit_area"}
     elif groupby_func in ("cummax", "cummin"):

pandas/tests/groupby/test_reductions.py

Lines changed: 36 additions & 0 deletions
@@ -514,6 +514,42 @@ def test_sum_skipna_object(skipna):
     tm.assert_series_equal(result, expected)


+@pytest.mark.parametrize(
+    "func, values, dtype, result_dtype",
+    [
+        ("prod", [0, 1, 3, np.nan, 4, 5, 6, 7, -8, 9], "float64", "float64"),
+        ("prod", [0, -1, 3, 4, 5, np.nan, 6, 7, 8, 9], "Float64", "Float64"),
+        ("prod", [0, 1, 3, -4, 5, 6, 7, -8, np.nan, 9], "Int64", "Int64"),
+        ("var", [0, -1, 3, 4, np.nan, 5, 6, 7, 8, 9], "float64", "float64"),
+        ("var", [0, 1, 3, -4, 5, 6, 7, -8, 9, np.nan], "Float64", "Float64"),
+        ("var", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "Int64", "Float64"),
+        ("std", [0, 1, 3, -4, 5, 6, 7, -8, np.nan, 9], "float64", "float64"),
+        ("std", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "Float64", "Float64"),
+        ("std", [0, 1, 3, -4, 5, 6, 7, -8, 9, np.nan], "Int64", "Float64"),
+        ("sem", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "float64", "float64"),
+        ("sem", [0, 1, 3, -4, 5, 6, 7, -8, np.nan, 9], "Float64", "Float64"),
+        ("sem", [0, -1, 3, 4, 5, -6, 7, 8, 9, np.nan], "Int64", "Float64"),
+    ],
+)
+def test_multifunc_skipna(func, values, dtype, result_dtype, skipna):
+    # GH#15675
+    df = DataFrame(
+        {
+            "val": values,
+            "cat": ["A", "B"] * 5,
+        }
+    ).astype({"val": dtype})
+    # We need to recast the expected values to the result_dtype as some operations
+    # change the dtype
+    expected = (
+        df.groupby("cat")["val"]
+        .apply(lambda x: getattr(x, func)(skipna=skipna))
+        .astype(result_dtype)
+    )
+    result = getattr(df.groupby("cat")["val"], func)(skipna=skipna)
+    tm.assert_series_equal(result, expected)
+
+
 def test_cython_median():
     arr = np.random.default_rng(2).standard_normal(1000)
     arr[::2] = np.nan
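
The test above checks the cythonized groupby reduction against applying the equivalent Series method group by group. A stripped-down version of that comparison (illustrative data, not taken from the diff):

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"cat": ["A", "B"] * 3, "val": [0.0, 1.0, np.nan, 3.0, 4.0, 5.0]})
    gb = df.groupby("cat")["val"]

    for skipna in (True, False):
        result = gb.var(skipna=skipna)
        expected = gb.apply(lambda s: s.var(skipna=skipna))
        pd.testing.assert_series_equal(result, expected, check_names=False)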
