Skip to content

Commit 41dda88

Browse files
feat: Add first, last support to GroupBy (#1969)
1 parent d17b711 commit 41dda88

File tree

6 files changed

+278
-2
lines changed

6 files changed

+278
-2
lines changed

bigframes/core/compile/compiled.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def project_window_op(
459459
for column in inputs:
460460
clauses.append((column.isnull(), ibis_types.null()))
461461
if window_spec.min_periods and len(inputs) > 0:
462-
if expression.op.skips_nulls:
462+
if not expression.op.nulls_count_for_min_values:
463463
# Most operations do not count NULL values towards min_periods
464464
per_col_does_count = (column.notnull() for column in inputs)
465465
# All inputs must be non-null for observation to count

bigframes/core/groupby/dataframe_group_by.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,48 @@ def kurt(
263263

264264
kurtosis = kurt
265265

266+
@validations.requires_ordering()
267+
def first(self, numeric_only: bool = False, min_count: int = -1) -> df.DataFrame:
268+
window_spec = window_specs.unbound(
269+
grouping_keys=tuple(self._by_col_ids),
270+
min_periods=min_count if min_count >= 0 else 0,
271+
)
272+
target_cols, index = self._aggregated_columns(numeric_only)
273+
block, firsts_ids = self._block.multi_apply_window_op(
274+
target_cols,
275+
agg_ops.FirstNonNullOp(),
276+
window_spec=window_spec,
277+
)
278+
block, _ = block.aggregate(
279+
self._by_col_ids,
280+
tuple(
281+
aggs.agg(firsts_id, agg_ops.AnyValueOp()) for firsts_id in firsts_ids
282+
),
283+
dropna=self._dropna,
284+
column_labels=index,
285+
)
286+
return df.DataFrame(block)
287+
288+
@validations.requires_ordering()
289+
def last(self, numeric_only: bool = False, min_count: int = -1) -> df.DataFrame:
290+
window_spec = window_specs.unbound(
291+
grouping_keys=tuple(self._by_col_ids),
292+
min_periods=min_count if min_count >= 0 else 0,
293+
)
294+
target_cols, index = self._aggregated_columns(numeric_only)
295+
block, lasts_ids = self._block.multi_apply_window_op(
296+
target_cols,
297+
agg_ops.LastNonNullOp(),
298+
window_spec=window_spec,
299+
)
300+
block, _ = block.aggregate(
301+
self._by_col_ids,
302+
tuple(aggs.agg(lasts_id, agg_ops.AnyValueOp()) for lasts_id in lasts_ids),
303+
dropna=self._dropna,
304+
column_labels=index,
305+
)
306+
return df.DataFrame(block)
307+
266308
def all(self) -> df.DataFrame:
267309
return self._aggregate_all(agg_ops.all_op)
268310

bigframes/core/groupby/series_group_by.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import bigframes.core.window as windows
3737
import bigframes.core.window_spec as window_specs
3838
import bigframes.dataframe as df
39+
import bigframes.dtypes
3940
import bigframes.operations.aggregations as agg_ops
4041
import bigframes.series as series
4142

@@ -162,6 +163,54 @@ def kurt(self, *args, **kwargs) -> series.Series:
162163

163164
kurtosis = kurt
164165

166+
@validations.requires_ordering()
167+
def first(self, numeric_only: bool = False, min_count: int = -1) -> series.Series:
168+
if numeric_only and not bigframes.dtypes.is_numeric(
169+
self._block.expr.get_column_type(self._value_column)
170+
):
171+
raise TypeError(
172+
f"Cannot use 'numeric_only' with non-numeric column {self._value_name}."
173+
)
174+
window_spec = window_specs.unbound(
175+
grouping_keys=tuple(self._by_col_ids),
176+
min_periods=min_count if min_count >= 0 else 0,
177+
)
178+
block, firsts_id = self._block.apply_window_op(
179+
self._value_column,
180+
agg_ops.FirstNonNullOp(),
181+
window_spec=window_spec,
182+
)
183+
block, _ = block.aggregate(
184+
self._by_col_ids,
185+
(aggs.agg(firsts_id, agg_ops.AnyValueOp()),),
186+
dropna=self._dropna,
187+
)
188+
return series.Series(block.with_column_labels([self._value_name]))
189+
190+
@validations.requires_ordering()
191+
def last(self, numeric_only: bool = False, min_count: int = -1) -> series.Series:
192+
if numeric_only and not bigframes.dtypes.is_numeric(
193+
self._block.expr.get_column_type(self._value_column)
194+
):
195+
raise TypeError(
196+
f"Cannot use 'numeric_only' with non-numeric column {self._value_name}."
197+
)
198+
window_spec = window_specs.unbound(
199+
grouping_keys=tuple(self._by_col_ids),
200+
min_periods=min_count if min_count >= 0 else 0,
201+
)
202+
block, firsts_id = self._block.apply_window_op(
203+
self._value_column,
204+
agg_ops.LastNonNullOp(),
205+
window_spec=window_spec,
206+
)
207+
block, _ = block.aggregate(
208+
self._by_col_ids,
209+
(aggs.agg(firsts_id, agg_ops.AnyValueOp()),),
210+
dropna=self._dropna,
211+
)
212+
return series.Series(block.with_column_labels([self._value_name]))
213+
165214
def prod(self, *args) -> series.Series:
166215
return self._aggregate(agg_ops.product_op)
167216

@@ -314,7 +363,7 @@ def _apply_window_op(
314363
discard_name=False,
315364
window: typing.Optional[window_specs.WindowSpec] = None,
316365
never_skip_nulls: bool = False,
317-
):
366+
) -> series.Series:
318367
"""Apply window op to groupby. Defaults to grouped cumulative window."""
319368
window_spec = window or window_specs.cumulative_rows(
320369
grouping_keys=tuple(self._by_col_ids)

bigframes/operations/aggregations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ def skips_nulls(self):
3333
"""Whether the window op skips null rows."""
3434
return True
3535

36+
@property
37+
def nulls_count_for_min_values(self) -> bool:
38+
"""Whether null values count for min_values."""
39+
return not self.skips_nulls
40+
3641
@property
3742
def implicitly_inherits_order(self):
3843
"""
@@ -480,6 +485,10 @@ class FirstNonNullOp(UnaryWindowOp):
480485
def skips_nulls(self):
481486
return False
482487

488+
@property
489+
def nulls_count_for_min_values(self) -> bool:
490+
return False
491+
483492

484493
@dataclasses.dataclass(frozen=True)
485494
class LastOp(UnaryWindowOp):
@@ -492,6 +501,10 @@ class LastNonNullOp(UnaryWindowOp):
492501
def skips_nulls(self):
493502
return False
494503

504+
@property
505+
def nulls_count_for_min_values(self) -> bool:
506+
return False
507+
495508

496509
@dataclasses.dataclass(frozen=True)
497510
class ShiftOp(UnaryWindowOp):

tests/system/small/test_groupby.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,101 @@ def test_series_groupby_quantile(scalars_df_index, scalars_pandas_df_index, q):
768768
pd.testing.assert_series_equal(
769769
pd_result, bf_result, check_dtype=False, check_index_type=False
770770
)
771+
772+
773+
@pytest.mark.parametrize(
774+
("numeric_only", "min_count"),
775+
[
776+
(True, 2),
777+
(False, -1),
778+
],
779+
)
780+
def test_series_groupby_first(
781+
scalars_df_index, scalars_pandas_df_index, numeric_only, min_count
782+
):
783+
bf_result = (
784+
scalars_df_index.groupby("string_col")["int64_col"].first(
785+
numeric_only=numeric_only, min_count=min_count
786+
)
787+
).to_pandas()
788+
pd_result = scalars_pandas_df_index.groupby("string_col")["int64_col"].first(
789+
numeric_only=numeric_only, min_count=min_count
790+
)
791+
pd.testing.assert_series_equal(
792+
pd_result,
793+
bf_result,
794+
)
795+
796+
797+
@pytest.mark.parametrize(
798+
("numeric_only", "min_count"),
799+
[
800+
(False, 4),
801+
(True, 0),
802+
],
803+
)
804+
def test_series_groupby_last(
805+
scalars_df_index, scalars_pandas_df_index, numeric_only, min_count
806+
):
807+
bf_result = (
808+
scalars_df_index.groupby("string_col")["int64_col"].last(
809+
numeric_only=numeric_only, min_count=min_count
810+
)
811+
).to_pandas()
812+
pd_result = scalars_pandas_df_index.groupby("string_col")["int64_col"].last(
813+
numeric_only=numeric_only, min_count=min_count
814+
)
815+
pd.testing.assert_series_equal(pd_result, bf_result)
816+
817+
818+
@pytest.mark.parametrize(
819+
("numeric_only", "min_count"),
820+
[
821+
(False, 4),
822+
(True, 0),
823+
],
824+
)
825+
def test_dataframe_groupby_first(
826+
scalars_df_index, scalars_pandas_df_index, numeric_only, min_count
827+
):
828+
# min_count seems to not work properly on older pandas
829+
pytest.importorskip("pandas", minversion="2.0.0")
830+
# bytes, dates not handling min_count properly in pandas
831+
bf_result = (
832+
scalars_df_index.drop(columns=["bytes_col", "date_col"])
833+
.groupby(scalars_df_index.int64_col % 2)
834+
.first(numeric_only=numeric_only, min_count=min_count)
835+
).to_pandas()
836+
pd_result = (
837+
scalars_pandas_df_index.drop(columns=["bytes_col", "date_col"])
838+
.groupby(scalars_pandas_df_index.int64_col % 2)
839+
.first(numeric_only=numeric_only, min_count=min_count)
840+
)
841+
pd.testing.assert_frame_equal(
842+
pd_result,
843+
bf_result,
844+
)
845+
846+
847+
@pytest.mark.parametrize(
848+
("numeric_only", "min_count"),
849+
[
850+
(True, 2),
851+
(False, -1),
852+
],
853+
)
854+
def test_dataframe_groupby_last(
855+
scalars_df_index, scalars_pandas_df_index, numeric_only, min_count
856+
):
857+
bf_result = (
858+
scalars_df_index.groupby(scalars_df_index.int64_col % 2).last(
859+
numeric_only=numeric_only, min_count=min_count
860+
)
861+
).to_pandas()
862+
pd_result = scalars_pandas_df_index.groupby(
863+
scalars_pandas_df_index.int64_col % 2
864+
).last(numeric_only=numeric_only, min_count=min_count)
865+
pd.testing.assert_frame_equal(
866+
pd_result,
867+
bf_result,
868+
)

third_party/bigframes_vendored/pandas/core/groupby/__init__.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,80 @@ def kurtosis(
537537
"""
538538
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
539539

540+
def first(self, numeric_only: bool = False, min_count: int = -1):
541+
"""
542+
Compute the first entry of each column within each group.
543+
544+
Defaults to skipping NA elements.
545+
546+
**Examples:**
547+
>>> import bigframes.pandas as bpd
548+
>>> bpd.options.display.progress_bar = None
549+
550+
>>> df = bpd.DataFrame(dict(A=[1, 1, 3], B=[None, 5, 6], C=[1, 2, 3]))
551+
>>> df.groupby("A").first()
552+
B C
553+
A
554+
1 5.0 1
555+
3 6.0 3
556+
<BLANKLINE>
557+
[2 rows x 2 columns]
558+
559+
>>> df.groupby("A").first(min_count=2)
560+
B C
561+
A
562+
1 <NA> 1
563+
3 <NA> <NA>
564+
<BLANKLINE>
565+
[2 rows x 2 columns]
566+
567+
Args:
568+
numeric_only (bool, default False):
569+
Include only float, int, boolean columns. If None, will attempt to use
570+
everything, then use only numeric data.
571+
min_count (int, default -1):
572+
The required number of valid values to perform the operation. If fewer
573+
than ``min_count`` valid values are present the result will be NA.
574+
575+
Returns:
576+
bigframes.pandas.DataFrame or bigframes.pandas.Series:
577+
First of values within each group.
578+
"""
579+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
580+
581+
def last(self, numeric_only: bool = False, min_count: int = -1):
582+
"""
583+
Compute the last entry of each column within each group.
584+
585+
Defaults to skipping NA elements.
586+
587+
**Examples:**
588+
>>> import bigframes.pandas as bpd
589+
>>> bpd.options.display.progress_bar = None
590+
591+
>>> df = bpd.DataFrame(dict(A=[1, 1, 3], B=[5, None, 6], C=[1, 2, 3]))
592+
>>> df.groupby("A").last()
593+
B C
594+
A
595+
1 5.0 2
596+
3 6.0 3
597+
<BLANKLINE>
598+
[2 rows x 2 columns]
599+
600+
Args:
601+
numeric_only (bool, default False):
602+
Include only float, int, boolean columns. If None, will attempt to use
603+
everything, then use only numeric data.
604+
min_count (int, default -1):
605+
The required number of valid values to perform the operation. If fewer
606+
than ``min_count`` valid values are present the result will be NA.
607+
608+
Returns:
609+
bigframes.pandas.DataFrame or bigframes.pandas.Series:
610+
Last of values within each group.
611+
"""
612+
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
613+
540614
def sum(
541615
self,
542616
numeric_only: bool = False,

0 commit comments

Comments
 (0)