Skip to content

Commit 6cd5ef2

Browse files
samukweku and ericmjl authored
[ENH] faster groupby top k (#1101)
* skeleton code
* fix tests, docstrings, examples
* changelog
* minor edits
* docstrings for tests
* allow by to be a list
* updates based on feedback

Co-authored-by: Eric Ma <[email protected]>
1 parent 4867b20 commit 6cd5ef2

File tree

3 files changed

+127
-69
lines changed

3 files changed

+127
-69
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- [ENH] Allow column selection/renaming within conditional_join. #1102 @samukweku.
77
- [ENH] New decorator `deprecated_kwargs` for breaking API. #1103 @Zeroto521
88
- [ENH] Extend select_columns to support non-string columns. #1105 @samukweku
9+
- [ENH] Performance improvement for groupby_topk. #1093 @samukweku
910

1011
## [v0.23.1] - 2022-05-03
1112

janitor/functions/groupby_topk.py

Lines changed: 83 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,124 @@
11
"""Implementation of the `groupby_topk` function"""
2-
from typing import Dict, Hashable
2+
from typing import Hashable, Union
33
import pandas_flavor as pf
44
import pandas as pd
55

66
from janitor.utils import check_column
7+
from janitor.utils import check, deprecated_alias
78

89

910
@pf.register_dataframe_method
11+
@deprecated_alias(groupby_column_name="by", sort_column_name="column")
1012
def groupby_topk(
1113
df: pd.DataFrame,
12-
groupby_column_name: Hashable,
13-
sort_column_name: Hashable,
14+
by: Union[list, Hashable],
15+
column: Hashable,
1416
k: int,
15-
sort_values_kwargs: Dict = None,
17+
dropna: bool = True,
18+
ascending: bool = True,
19+
ignore_index: bool = True,
1620
) -> pd.DataFrame:
1721
"""
1822
Return top `k` rows from a groupby of a set of columns.
1923
20-
Returns a DataFrame that has the top `k` values grouped by `groupby_column_name`
21-
and sorted by `sort_column_name`.
22-
Additional parameters to the sorting (such as `ascending=True`)
23-
can be passed using `sort_values_kwargs`.
24+
Returns a DataFrame that has the top `k` values per `column`,
25+
grouped by `by`. Under the hood it uses `nlargest`/`nsmallest`
26+
for numeric columns, which avoids sorting the entire dataframe,
27+
and is usually more performant. For non-numeric columns, `pd.Series.sort_values`
28+
is used.
29+
No sorting is done to the `by` column(s); the order is maintained
30+
in the final output.
2431
25-
List of all sort_values() parameters can be found
26-
[here](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.sort_values.html).
2732
2833
Example:
2934
3035
>>> import pandas as pd
3136
>>> import janitor
32-
>>> df = pd.DataFrame({
33-
... "age": [20, 23, 22, 43, 21],
34-
... "id": [1, 4, 6, 2, 5],
35-
... "result": ["pass", "pass", "fail", "pass", "fail"]
36-
... })
37+
>>> df = pd.DataFrame(
38+
... {
39+
... "age": [20, 23, 22, 43, 21],
40+
... "id": [1, 4, 6, 2, 5],
41+
... "result": ["pass", "pass", "fail", "pass", "fail"],
42+
... }
43+
... )
3744
>>> df
3845
age id result
3946
0 20 1 pass
4047
1 23 4 pass
4148
2 22 6 fail
4249
3 43 2 pass
4350
4 21 5 fail
44-
>>> df.groupby_topk('result', 'age', 3) # Ascending top 3
45-
... # doctest: +NORMALIZE_WHITESPACE
46-
age id result
47-
result
48-
fail 4 21 5 fail
49-
2 22 6 fail
50-
pass 0 20 1 pass
51-
1 23 4 pass
52-
3 43 2 pass
53-
>>> df.groupby_topk('result', 'age', 2, {'ascending':False}) # Descending top 2
54-
... # doctest: +NORMALIZE_WHITESPACE
55-
age id result
56-
result
57-
fail 2 22 6 fail
58-
4 21 5 fail
59-
pass 3 43 2 pass
60-
1 23 4 pass
51+
52+
Ascending top 3:
53+
54+
>>> df.groupby_topk(by="result", column="age", k=3)
55+
age id result
56+
0 20 1 pass
57+
1 23 4 pass
58+
2 43 2 pass
59+
3 21 5 fail
60+
4 22 6 fail
61+
62+
Descending top 2:
63+
64+
>>> df.groupby_topk(
65+
... by="result", column="age", k=2, ascending=False, ignore_index=False
66+
... )
67+
age id result
68+
3 43 2 pass
69+
1 23 4 pass
70+
2 22 6 fail
71+
4 21 5 fail
6172
6273
6374
:param df: A pandas DataFrame.
64-
:param groupby_column_name: Column name to group input DataFrame `df` by.
65-
:param sort_column_name: Name of the column to sort along the
66-
input DataFrame `df`.
67-
:param k: Number of top rows to return from each group after sorting.
68-
:param sort_values_kwargs: Arguments to be passed to sort_values function.
69-
:returns: A pandas DataFrame with top `k` rows that are grouped by
70-
`groupby_column_name` column with each group sorted along the
71-
column `sort_column_name`.
75+
:param by: Column name(s) to group input DataFrame `df` by.
76+
:param column: Name of the column that determines `k` rows
77+
to return.
78+
:param k: Number of top rows to return for each group.
79+
:param dropna: If `True`, and `NA` values exist in `by`, the `NA`
80+
values are not used in the groupby computation to get the relevant
81+
`k` rows. If `False`, and `NA` values exist in `by`, then the `NA`
82+
values are used in the groupby computation to get the relevant
83+
`k` rows. The default is `True`.
84+
:param ascending: Default is `True`. If `True`, the smallest top `k` rows,
85+
determined by `column` are returned; if `False`, the largest top `k` rows,
86+
determined by `column` are returned.
87+
:param ignore_index: Default `True`. If `True`,
88+
the original index is ignored. If `False`, the original index
89+
for the top `k` rows is retained.
90+
:returns: A pandas DataFrame with top `k` rows per `column`, grouped by `by`.
7291
:raises ValueError: if `k` is less than 1.
73-
:raises ValueError: if `groupby_column_name` not in DataFrame `df`.
74-
:raises ValueError: if `sort_column_name` not in DataFrame `df`.
75-
:raises KeyError: if `inplace:True` is present in `sort_values_kwargs`.
7692
""" # noqa: E501
7793

78-
# Convert the default sort_values_kwargs from None to empty Dict
79-
sort_values_kwargs = sort_values_kwargs or {}
94+
if isinstance(by, Hashable):
95+
by = [by]
8096

81-
# Check if groupby_column_name and sort_column_name exists in the DataFrame
82-
check_column(df, [groupby_column_name, sort_column_name])
97+
check("by", by, [Hashable, list])
98+
99+
check_column(df, [column])
100+
check_column(df, by)
83101

84-
# Check if k is greater than 0.
85102
if k < 1:
86103
raise ValueError(
87-
"Numbers of rows per group to be returned must be greater than 0."
104+
"Numbers of rows per group "
105+
"to be returned must be greater than 0."
88106
)
89107

90-
# Check if inplace:True in sort values kwargs because it returns None
91-
if (
92-
"inplace" in sort_values_kwargs.keys()
93-
and sort_values_kwargs["inplace"]
94-
):
95-
raise KeyError("Cannot use `inplace=True` in `sort_values_kwargs`.")
108+
indices = df.groupby(by=by, dropna=dropna, sort=False, observed=True)
109+
indices = indices[column]
110+
111+
try:
112+
if ascending:
113+
indices = indices.nsmallest(n=k)
114+
else:
115+
indices = indices.nlargest(n=k)
116+
except TypeError:
117+
indices = indices.apply(
118+
lambda d: d.sort_values(ascending=ascending).head(k)
119+
)
96120

97-
return df.groupby(groupby_column_name).apply(
98-
lambda d: d.sort_values(sort_column_name, **sort_values_kwargs).head(k)
99-
)
121+
indices = indices.index.get_level_values(-1)
122+
if ignore_index:
123+
return df.loc[indices].reset_index(drop=True)
124+
return df.loc[indices]

tests/functions/test_groupby_topk.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
@pytest.fixture
77
def df():
8+
"""fixture for groupby_topk"""
89
return pd.DataFrame(
910
[
1011
{"age": 22, "major": "science", "ID": 145, "result": "pass"},
@@ -17,44 +18,75 @@ def df():
1718
)
1819

1920

21+
def test_dtype_by(df):
22+
"""Check dtype for by."""
23+
with pytest.raises(TypeError):
24+
df.groupby_topk(by={"result"}, column="age", k=2)
25+
26+
2027
def test_ascending_groupby_k_2(df):
2128
"""Test ascending group by, k=2"""
22-
expected = df.groupby("result").apply(
23-
lambda d: d.sort_values("age").head(2)
29+
expected = (
30+
df.groupby("result", sort=False)
31+
.apply(lambda d: d.sort_values("age").head(2))
32+
.droplevel(0)
33+
)
34+
assert_frame_equal(
35+
df.groupby_topk("result", "age", 2, ignore_index=False), expected
36+
)
37+
38+
39+
def test_ascending_groupby_non_numeric(df):
40+
"""Test output for non-numeric column"""
41+
expected = (
42+
df.groupby("result", sort=False)
43+
.apply(lambda d: d.sort_values("major").head(2))
44+
.droplevel(0)
45+
)
46+
assert_frame_equal(
47+
df.groupby_topk("result", "major", 2, ignore_index=False), expected
2448
)
25-
assert_frame_equal(df.groupby_topk("result", "age", 2), expected)
2649

2750

2851
def test_descending_groupby_k_3(df):
2952
"""Test descending group by, k=3"""
30-
expected = df.groupby("result").apply(
31-
lambda d: d.sort_values("age", ascending=False).head(3)
53+
expected = (
54+
df.groupby("result", sort=False)
55+
.apply(lambda d: d.sort_values("age", ascending=False).head(3))
56+
.droplevel(0)
57+
.reset_index(drop=True)
3258
)
3359
assert_frame_equal(
34-
df.groupby_topk("result", "age", 3, {"ascending": False}), expected
60+
df.groupby_topk("result", "age", 3, ascending=False), expected
3561
)
3662

3763

3864
def test_wrong_groupby_column_name(df):
3965
"""Raise Value Error if wrong groupby column name is provided."""
40-
with pytest.raises(ValueError):
66+
with pytest.raises(
67+
ValueError, match="RESULT not present in dataframe columns!"
68+
):
4169
df.groupby_topk("RESULT", "age", 3)
4270

4371

4472
def test_wrong_sort_column_name(df):
4573
"""Raise Value Error if wrong sort column name is provided."""
46-
with pytest.raises(ValueError):
74+
with pytest.raises(
75+
ValueError, match="Age not present in dataframe columns!"
76+
):
4777
df.groupby_topk("result", "Age", 3)
4878

4979

5080
def test_negative_k(df):
5181
"""Raises Value Error if k is less than 1 (negative or 0)."""
52-
with pytest.raises(ValueError):
82+
with pytest.raises(
83+
ValueError,
84+
match="Numbers of rows per group.+",
85+
):
5386
df.groupby_topk("result", "age", -2)
54-
with pytest.raises(ValueError):
55-
df.groupby_topk("result", "age", 0)
5687

5788

89+
@pytest.mark.xfail(reason="sort_value_kwargs parameter deprecated.")
5890
def test_inplace(df):
5991
"""Raise Key Error if inplace is True in sort_values_kwargs"""
6092
with pytest.raises(KeyError):

0 commit comments

Comments
 (0)