|
1 | 1 | """Implementation of the `groupby_topk` function"""
|
2 |
| -from typing import Dict, Hashable |
| 2 | +from typing import Hashable, Union |
3 | 3 | import pandas_flavor as pf
|
4 | 4 | import pandas as pd
|
5 | 5 |
|
6 | 6 | from janitor.utils import check_column
|
| 7 | +from janitor.utils import check, deprecated_alias |
7 | 8 |
|
8 | 9 |
|
9 | 10 | @pf.register_dataframe_method
|
| 11 | +@deprecated_alias(groupby_column_name="by", sort_column_name="column") |
10 | 12 | def groupby_topk(
|
11 | 13 | df: pd.DataFrame,
|
12 |
| - groupby_column_name: Hashable, |
13 |
| - sort_column_name: Hashable, |
| 14 | + by: Union[list, Hashable], |
| 15 | + column: Hashable, |
14 | 16 | k: int,
|
15 |
| - sort_values_kwargs: Dict = None, |
| 17 | + dropna: bool = True, |
| 18 | + ascending: bool = True, |
| 19 | + ignore_index: bool = True, |
16 | 20 | ) -> pd.DataFrame:
|
17 | 21 | """
|
18 | 22 | Return top `k` rows from a groupby of a set of columns.
|
19 | 23 |
|
20 |
| - Returns a DataFrame that has the top `k` values grouped by `groupby_column_name` |
21 |
| - and sorted by `sort_column_name`. |
22 |
| - Additional parameters to the sorting (such as `ascending=True`) |
23 |
| - can be passed using `sort_values_kwargs`. |
| 24 | + Returns a DataFrame that has the top `k` values per `column`, |
| 25 | + grouped by `by`. Under the hood it uses `nlargest/nsmallest`, |
| 26 | + for numeric columns, which avoids sorting the entire dataframe, |
| 27 | + and is usually more performant. For non-numeric columns, `pd.sort_values` |
| 28 | + is used. |
| 29 | + No sorting is done to the `by` column(s); the order is maintained |
| 30 | + in the final output. |
24 | 31 |
|
25 |
| - List of all sort_values() parameters can be found |
26 |
| - [here](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.sort_values.html). |
27 | 32 |
|
28 | 33 | Example:
|
29 | 34 |
|
30 | 35 | >>> import pandas as pd
|
31 | 36 | >>> import janitor
|
32 |
| - >>> df = pd.DataFrame({ |
33 |
| - ... "age": [20, 23, 22, 43, 21], |
34 |
| - ... "id": [1, 4, 6, 2, 5], |
35 |
| - ... "result": ["pass", "pass", "fail", "pass", "fail"] |
36 |
| - ... }) |
| 37 | + >>> df = pd.DataFrame( |
| 38 | + ... { |
| 39 | + ... "age": [20, 23, 22, 43, 21], |
| 40 | + ... "id": [1, 4, 6, 2, 5], |
| 41 | + ... "result": ["pass", "pass", "fail", "pass", "fail"], |
| 42 | + ... } |
| 43 | + ... ) |
37 | 44 | >>> df
|
38 | 45 | age id result
|
39 | 46 | 0 20 1 pass
|
40 | 47 | 1 23 4 pass
|
41 | 48 | 2 22 6 fail
|
42 | 49 | 3 43 2 pass
|
43 | 50 | 4 21 5 fail
|
44 |
| - >>> df.groupby_topk('result', 'age', 3) # Ascending top 3 |
45 |
| - ... # doctest: +NORMALIZE_WHITESPACE |
46 |
| - age id result |
47 |
| - result |
48 |
| - fail 4 21 5 fail |
49 |
| - 2 22 6 fail |
50 |
| - pass 0 20 1 pass |
51 |
| - 1 23 4 pass |
52 |
| - 3 43 2 pass |
53 |
| - >>> df.groupby_topk('result', 'age', 2, {'ascending':False}) # Descending top 2 |
54 |
| - ... # doctest: +NORMALIZE_WHITESPACE |
55 |
| - age id result |
56 |
| - result |
57 |
| - fail 2 22 6 fail |
58 |
| - 4 21 5 fail |
59 |
| - pass 3 43 2 pass |
60 |
| - 1 23 4 pass |
| 51 | +
|
| 52 | + Ascending top 3: |
| 53 | +
|
| 54 | + >>> df.groupby_topk(by="result", column="age", k=3) |
| 55 | + age id result |
| 56 | + 0 20 1 pass |
| 57 | + 1 23 4 pass |
| 58 | + 2 43 2 pass |
| 59 | + 3 21 5 fail |
| 60 | + 4 22 6 fail |
| 61 | +
|
| 62 | + Descending top 2: |
| 63 | +
|
| 64 | + >>> df.groupby_topk( |
| 65 | + ... by="result", column="age", k=2, ascending=False, ignore_index=False |
| 66 | + ... ) |
| 67 | + age id result |
| 68 | + 3 43 2 pass |
| 69 | + 1 23 4 pass |
| 70 | + 2 22 6 fail |
| 71 | + 4 21 5 fail |
61 | 72 |
|
62 | 73 |
|
63 | 74 | :param df: A pandas DataFrame.
|
64 |
| - :param groupby_column_name: Column name to group input DataFrame `df` by. |
65 |
| - :param sort_column_name: Name of the column to sort along the |
66 |
| - input DataFrame `df`. |
67 |
| - :param k: Number of top rows to return from each group after sorting. |
68 |
| - :param sort_values_kwargs: Arguments to be passed to sort_values function. |
69 |
| - :returns: A pandas DataFrame with top `k` rows that are grouped by |
70 |
| - `groupby_column_name` column with each group sorted along the |
71 |
| - column `sort_column_name`. |
| 75 | + :param by: Column name(s) to group input DataFrame `df` by. |
| 76 | + :param column: Name of the column that determines `k` rows |
| 77 | + to return. |
| 78 | + :param k: Number of top rows to return for each group. |
| 79 | + :param dropna: If `True`, and `NA` values exist in `by`, the `NA` |
| 80 | + values are not used in the groupby computation to get the relevant |
| 81 | + `k` rows. If `False`, and `NA` values exist in `by`, then the `NA` |
| 82 | + values are used in the groupby computation to get the relevant |
| 83 | + `k` rows. The default is `True`. |
| 84 | + :param ascending: Default is `True`. If `True`, the smallest top `k` rows, |
| 85 | + determined by `column` are returned; if `False, the largest top `k` rows, |
| 86 | + determined by `column` are returned. |
| 87 | + :param ignore_index: Default `True`. If `True`, |
| 88 | + the original index is ignored. If `False`, the original index |
| 89 | + for the top `k` rows is retained. |
| 90 | + :returns: A pandas DataFrame with top `k` rows per `column`, grouped by `by`. |
72 | 91 | :raises ValueError: if `k` is less than 1.
|
73 |
| - :raises ValueError: if `groupby_column_name` not in DataFrame `df`. |
74 |
| - :raises ValueError: if `sort_column_name` not in DataFrame `df`. |
75 |
| - :raises KeyError: if `inplace:True` is present in `sort_values_kwargs`. |
76 | 92 | """ # noqa: E501
|
77 | 93 |
|
78 |
| - # Convert the default sort_values_kwargs from None to empty Dict |
79 |
| - sort_values_kwargs = sort_values_kwargs or {} |
| 94 | + if isinstance(by, Hashable): |
| 95 | + by = [by] |
80 | 96 |
|
81 |
| - # Check if groupby_column_name and sort_column_name exists in the DataFrame |
82 |
| - check_column(df, [groupby_column_name, sort_column_name]) |
| 97 | + check("by", by, [Hashable, list]) |
| 98 | + |
| 99 | + check_column(df, [column]) |
| 100 | + check_column(df, by) |
83 | 101 |
|
84 |
| - # Check if k is greater than 0. |
85 | 102 | if k < 1:
|
86 | 103 | raise ValueError(
|
87 |
| - "Numbers of rows per group to be returned must be greater than 0." |
| 104 | + "Numbers of rows per group " |
| 105 | + "to be returned must be greater than 0." |
88 | 106 | )
|
89 | 107 |
|
90 |
| - # Check if inplace:True in sort values kwargs because it returns None |
91 |
| - if ( |
92 |
| - "inplace" in sort_values_kwargs.keys() |
93 |
| - and sort_values_kwargs["inplace"] |
94 |
| - ): |
95 |
| - raise KeyError("Cannot use `inplace=True` in `sort_values_kwargs`.") |
| 108 | + indices = df.groupby(by=by, dropna=dropna, sort=False, observed=True) |
| 109 | + indices = indices[column] |
| 110 | + |
| 111 | + try: |
| 112 | + if ascending: |
| 113 | + indices = indices.nsmallest(n=k) |
| 114 | + else: |
| 115 | + indices = indices.nlargest(n=k) |
| 116 | + except TypeError: |
| 117 | + indices = indices.apply( |
| 118 | + lambda d: d.sort_values(ascending=ascending).head(k) |
| 119 | + ) |
96 | 120 |
|
97 |
| - return df.groupby(groupby_column_name).apply( |
98 |
| - lambda d: d.sort_values(sort_column_name, **sort_values_kwargs).head(k) |
99 |
| - ) |
| 121 | + indices = indices.index.get_level_values(-1) |
| 122 | + if ignore_index: |
| 123 | + return df.loc[indices].reset_index(drop=True) |
| 124 | + return df.loc[indices] |
0 commit comments