Skip to content

Commit 44202bc

Browse files
authored
feat: add groupby head API (#791)
* feat: add groupby head API * update annotations * update order
1 parent c8d16c0 commit 44202bc

File tree

3 files changed

+62
-0
lines changed

3 files changed

+62
-0
lines changed

bigframes/core/blocks.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,26 @@ def _normalize_expression(
13841384
raise ValueError("Unexpected number of value columns.")
13851385
return expr.select_columns([*index_columns, *value_columns])
13861386

1387+
def grouped_head(
1388+
self,
1389+
by_column_ids: typing.Sequence[str],
1390+
value_columns: typing.Sequence[str],
1391+
n: int,
1392+
):
1393+
window_spec = window_specs.cumulative_rows(grouping_keys=tuple(by_column_ids))
1394+
1395+
block, result_id = self.apply_window_op(
1396+
value_columns[0],
1397+
agg_ops.rank_op,
1398+
window_spec=window_spec,
1399+
)
1400+
1401+
cond = ops.lt_op.as_expr(result_id, ex.const(n + 1))
1402+
block, cond_id = block.project_expr(cond)
1403+
block = block.filter_by_id(cond_id)
1404+
if value_columns:
1405+
return block.select_columns(value_columns)
1406+
13871407
def slice(
13881408
self,
13891409
start: typing.Optional[int] = None,

bigframes/core/groupby/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ def __getitem__(
104104
dropna=self._dropna,
105105
)
106106

107+
def head(self, n: int = 5) -> df.DataFrame:
108+
block = self._block
109+
if self._dropna:
110+
block = block_ops.dropna(self._block, self._by_col_ids, how="any")
111+
return df.DataFrame(
112+
block.grouped_head(
113+
by_column_ids=self._by_col_ids,
114+
value_columns=self._block.value_columns,
115+
n=n,
116+
)
117+
)
118+
107119
def size(self) -> typing.Union[df.DataFrame, series.Series]:
108120
agg_block, _ = self._block.aggregate_size(
109121
by_column_ids=self._by_col_ids,
@@ -498,6 +510,16 @@ def __init__(
498510
self._value_name = value_name
499511
self._dropna = dropna # Applies to aggregations but not windowing
500512

513+
def head(self, n: int = 5) -> series.Series:
514+
block = self._block
515+
if self._dropna:
516+
block = block_ops.dropna(self._block, self._by_col_ids, how="any")
517+
return series.Series(
518+
block.grouped_head(
519+
by_column_ids=self._by_col_ids, value_columns=[self._value_column], n=n
520+
)
521+
)
522+
501523
def all(self) -> series.Series:
502524
return self._aggregate(agg_ops.all_op)
503525

tests/system/small/test_groupby.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def test_dataframe_groupby_numeric_aggregate(
5353
pd.testing.assert_frame_equal(pd_result, bf_result_computed, check_dtype=False)
5454

5555

56+
def test_dataframe_groupby_head(scalars_df_index, scalars_pandas_df_index):
57+
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
58+
bf_result = scalars_df_index[col_names].groupby("bool_col").head(2).to_pandas()
59+
pd_result = scalars_pandas_df_index[col_names].groupby("bool_col").head(2)
60+
pd.testing.assert_frame_equal(pd_result, bf_result, check_dtype=False)
61+
62+
5663
def test_dataframe_groupby_median(scalars_df_index, scalars_pandas_df_index):
5764
col_names = ["int64_too", "float64_col", "int64_col", "bool_col", "string_col"]
5865
bf_result = (
@@ -442,6 +449,19 @@ def test_series_groupby_agg_list(scalars_df_index, scalars_pandas_df_index):
442449
)
443450

444451

452+
@pytest.mark.parametrize("dropna", [True, False])
453+
def test_series_groupby_head(scalars_df_index, scalars_pandas_df_index, dropna):
454+
bf_result = (
455+
scalars_df_index.groupby("bool_col", dropna=dropna)["int64_too"]
456+
.head(1)
457+
.to_pandas()
458+
)
459+
pd_result = scalars_pandas_df_index.groupby("bool_col", dropna=dropna)[
460+
"int64_too"
461+
].head(1)
462+
pd.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
463+
464+
445465
def test_series_groupby_kurt(scalars_df_index, scalars_pandas_df_index):
446466
bf_result = (
447467
scalars_df_index["int64_too"]

0 commit comments

Comments
 (0)