Skip to content

Commit 9e7cb25

Browse files
samukwekusamuel.oranyeli
andauthored
1486 add summarise method to groupby (#1488)
* add support for groupby.summarise * add tests * remove irrelevant functions --------- Co-authored-by: samuel.oranyeli <[email protected]>
1 parent 85f126a commit 9e7cb25

File tree

2 files changed

+58
-72
lines changed

2 files changed

+58
-72
lines changed

janitor/functions/summarise.py

Lines changed: 19 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414
from janitor.functions.select import get_index_labels
1515

1616

17+
@pf.register_groupby_method
1718
@pf.register_dataframe_method
1819
def summarise(
19-
df: pd.DataFrame,
20+
df: pd.DataFrame | DataFrameGroupBy,
2021
*args: tuple[dict | tuple],
2122
by: Any = None,
2223
) -> pd.DataFrame:
@@ -107,6 +108,8 @@ def summarise(
107108
Arguments supported in `pd.DataFrame.groupby`
108109
can also be passed to `by` via a dictionary.
109110
111+
If `df` is a `DataFrameGroupBy` object, `by` is ignored.
112+
110113
Examples:
111114
>>> import pandas as pd
112115
>>> import janitor
@@ -160,7 +163,7 @@ def summarise(
160163
103202 4.0
161164
162165
Args:
163-
df: A pandas DataFrame.
166+
df: A pandas DataFrame or DataFrameGroupBy object.
164167
args: Either a dictionary or a tuple.
165168
by: Column(s) to group by.
166169
@@ -171,8 +174,10 @@ def summarise(
171174
A pandas DataFrame with aggregated columns.
172175
173176
""" # noqa: E501
174-
175-
if by is not None:
177+
if isinstance(df, DataFrameGroupBy):
178+
by = df
179+
df = df.obj
180+
elif by is not None:
176181
# it is assumed that by is created from df
177182
# onus is on user to ensure that
178183
if isinstance(by, DataFrameGroupBy):
@@ -233,7 +238,7 @@ def _aggfunc(arg, df, by):
233238
val = df
234239
else:
235240
val = by
236-
outcome = _process_maybe_callable(func=arg, obj=val)
241+
outcome = apply_if_callable(maybe_callable=arg, obj=val)
237242
if isinstance(outcome, pd.Series):
238243
if not outcome.name:
239244
raise ValueError("Ensure the pandas Series object has a name")
@@ -270,10 +275,11 @@ def _(arg, df, by):
270275
if len(aggfunc) != 2:
271276
raise ValueError("the tuple has to be a length of 2")
272277
column, func = aggfunc
273-
column_ = _handle_tuple_groupby_selection(by=by, column=column)
274-
column = _apply_func_to_obj(aggfunc=func, obj=val[column_])
275-
if isinstance(column, pd.DataFrame) and column.shape[-1] == 1:
278+
column = val.agg({column: func})
279+
try:
276280
column = column.squeeze()
281+
except AttributeError:
282+
pass
277283
column = _convert_obj_to_named_series(
278284
obj=column,
279285
column_name=column_name,
@@ -285,54 +291,20 @@ def _(arg, df, by):
285291
f"instead got {type(column)}"
286292
)
287293
else:
288-
column_ = _handle_tuple_groupby_selection(
289-
by=by, column=column_name
290-
)
291-
column = _apply_func_to_obj(aggfunc=aggfunc, obj=val[column_])
294+
column = val.agg({column_name: aggfunc})
295+
try:
296+
column = column.squeeze()
297+
except AttributeError:
298+
pass
292299
column = _convert_obj_to_named_series(
293300
obj=column,
294301
column_name=column_name,
295302
function=aggfunc,
296303
)
297-
column = _rename_column_in_by(
298-
column=column, column_name=column_name, by=by
299-
)
300304
contents.append(column)
301305
return contents
302306

303307

304-
def _process_maybe_callable(func: callable, obj):
305-
"""Function to handle callables"""
306-
try:
307-
column = obj.agg(func)
308-
except: # noqa: E722
309-
column = apply_if_callable(maybe_callable=func, obj=obj)
310-
return column
311-
312-
313-
def _process_maybe_string(func: str, obj):
314-
"""Function to handle pandas string functions"""
315-
# treat as a pandas approved string function
316-
# https://pandas.pydata.org/docs/user_guide/groupby.html#built-in-aggregation-methods
317-
return obj.agg(func)
318-
319-
320-
def _apply_func_to_obj(aggfunc, obj):
321-
"""Handle str/callables within a dictionary"""
322-
if isinstance(aggfunc, str):
323-
return _process_maybe_string(func=aggfunc, obj=obj)
324-
return _process_maybe_callable(func=aggfunc, obj=obj)
325-
326-
327-
def _handle_tuple_groupby_selection(by: Any, column: Any):
328-
"""
329-
Properly handle a tuple column selection in the presence of a groupby
330-
"""
331-
if (by is not None) and isinstance(column, tuple):
332-
return [column]
333-
return column
334-
335-
336308
def _convert_obj_to_named_series(obj, function: Any, column_name: Any):
337309
if isinstance(obj, pd.Series):
338310
obj.name = column_name
@@ -344,12 +316,3 @@ def _convert_obj_to_named_series(obj, function: Any, column_name: Any):
344316
else:
345317
function_name = function.__name__
346318
return pd.Series(data=obj, index=[function_name], name=column_name)
347-
348-
349-
def _rename_column_in_by(column, column_name, by):
350-
if by is None:
351-
return column
352-
elif isinstance(column, pd.DataFrame) and is_scalar(column_name):
353-
columns = pd.MultiIndex.from_product([[column_name], column.columns])
354-
column.columns = columns
355-
return column

tests/functions/test_summarise.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ def test_summarise_by_callable_grp(df_summarise):
7777
assert_frame_equal(actual, expected)
7878

7979

80+
def test_summarise_by_callable_grp_grouped(df_summarise):
81+
"""Test output for a callable"""
82+
grp = df_summarise.groupby("combine_id")
83+
actual = grp.summarise(lambda df: df.sum())
84+
expected = grp.sum()
85+
assert_frame_equal(actual, expected)
86+
87+
8088
def test_summarise_dict_df_str(df_summarise):
8189
"""Test output for a dictionary"""
8290
actual = df_summarise.summarise({"avg_run": "mean"})
@@ -164,6 +172,13 @@ def test_summarise_by_tuple(df_summarise):
164172
assert_frame_equal(actual, expected)
165173

166174

175+
def test_summarise_by_tuple_grouped(df_summarise):
176+
"""Test output for a tuple"""
177+
actual = df_summarise.groupby("combine_id").summarise(("avg_run", "mean"))
178+
expected = df_summarise.groupby("combine_id").agg({"avg_run": "mean"})
179+
assert_frame_equal(actual, expected)
180+
181+
167182
def test_summarise_tuple_df_callable(df_summarise):
168183
"""Test output for a tuple"""
169184
actual = df_summarise.summarise(("avg_run", lambda df: df.sum()))
@@ -180,24 +195,12 @@ def test_summarise_tuple_by_callable(df_summarise):
180195
assert_frame_equal(actual, expected)
181196

182197

183-
def test_summarise_tuple_by_callable_dataframe(df_summarise):
184-
"""Test output for a tuple"""
185-
actual = df_summarise.summarise(
186-
("avg_run", lambda df: df.agg(["sum", "mean"])), by="combine_id"
187-
)
188-
expected = df_summarise.groupby("combine_id").agg(
189-
{"avg_run": ["sum", "mean"]}
190-
)
191-
assert_frame_equal(actual, expected)
192-
193-
194-
def test_summarise_tuple_grouped_object(df_summarise):
198+
def test_summarise_tuple_by_callable_grouped(df_summarise):
195199
"""Test output for a tuple"""
196-
grp = df_summarise.groupby("combine_id")
197-
actual = df_summarise.summarise(
198-
("avg_run", lambda df: df.agg(["sum", "mean"])), by=grp
200+
actual = df_summarise.groupby("combine_id").summarise(
201+
("avg_run", lambda df: df.sum())
199202
)
200-
expected = grp.agg({"avg_run": ["sum", "mean"]})
203+
expected = df_summarise.groupby("combine_id").agg({"avg_run": "sum"})
201204
assert_frame_equal(actual, expected)
202205

203206

@@ -273,6 +276,26 @@ def test_summarise_MI_different_levels_tuple(dfmi):
273276
assert_frame_equal(actual, expected)
274277

275278

279+
def test_summarise_MI_different_levels_tuple_grouped(dfmi):
280+
"""Test summarise on a MultiIndex"""
281+
actual = dfmi.groupby(level="A").summarise(
282+
{("a", "bar"): "sum", ("rar",): (("a", "foo"), "mean")},
283+
("b", "min"),
284+
)
285+
actual.columns = ["A", "B", "C", "D"]
286+
grp = dfmi.groupby(level="A")
287+
expected = grp.agg(
288+
{
289+
("a", "bar"): "sum",
290+
("a", "foo"): "mean",
291+
("b", "bah"): "min",
292+
("b", "foo"): "min",
293+
}
294+
)
295+
expected.columns = ["A", "B", "C", "D"]
296+
assert_frame_equal(actual, expected)
297+
298+
276299
def test_summarise_MI_different_levels_dataframe(dfmi):
277300
"""raise if dictionary value is a tuple and the returned aggregate is a DataFrame"""
278301
with pytest.raises(TypeError, match="Expected a pandas Series object;.+"):

0 commit comments

Comments
 (0)