Skip to content

Commit 1395a50

Browse files
authored
feat: add thresh param for Dataframe.dropna (#1885)
* support thresh in dropna * update docstring, and polish function * fix mypy
1 parent 813624d commit 1395a50

File tree

4 files changed

+124
-20
lines changed

4 files changed

+124
-20
lines changed

bigframes/core/block_transforms.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,8 @@ def rank(
522522
def dropna(
523523
block: blocks.Block,
524524
column_ids: typing.Sequence[str],
525-
how: typing.Literal["all", "any"] = "any",
525+
how: str = "any",
526+
thresh: typing.Optional[int] = None,
526527
subset: Optional[typing.Sequence[str]] = None,
527528
):
528529
"""
@@ -531,17 +532,38 @@ def dropna(
531532
if subset is None:
532533
subset = column_ids
533534

535+
# Predicates to check for non-null values in the subset of columns
534536
predicates = [
535537
ops.notnull_op.as_expr(column_id)
536538
for column_id in column_ids
537539
if column_id in subset
538540
]
541+
539542
if len(predicates) == 0:
540543
return block
541-
if how == "any":
542-
predicate = functools.reduce(ops.and_op.as_expr, predicates)
543-
else: # "all"
544-
predicate = functools.reduce(ops.or_op.as_expr, predicates)
544+
545+
if thresh is not None:
546+
# Handle single predicate case
547+
if len(predicates) == 1:
548+
count_expr = ops.AsTypeOp(pd.Int64Dtype()).as_expr(predicates[0])
549+
else:
550+
# Sum the boolean expressions to count non-null values
551+
count_expr = functools.reduce(
552+
lambda a, b: ops.add_op.as_expr(
553+
ops.AsTypeOp(pd.Int64Dtype()).as_expr(a),
554+
ops.AsTypeOp(pd.Int64Dtype()).as_expr(b),
555+
),
556+
predicates,
557+
)
558+
# Filter rows where count >= thresh
559+
predicate = ops.ge_op.as_expr(count_expr, ex.const(thresh))
560+
else:
561+
# Only handle 'how' parameter when thresh is not specified
562+
if how == "any":
563+
predicate = functools.reduce(ops.and_op.as_expr, predicates)
564+
else: # "all"
565+
predicate = functools.reduce(ops.or_op.as_expr, predicates)
566+
545567
return block.filter(predicate)
546568

547569

bigframes/dataframe.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,6 +2802,7 @@ def dropna(
28022802
*,
28032803
axis: int | str = 0,
28042804
how: str = "any",
2805+
thresh: typing.Optional[int] = None,
28052806
subset: typing.Union[None, blocks.Label, Sequence[blocks.Label]] = None,
28062807
inplace: bool = False,
28072808
ignore_index=False,
@@ -2810,8 +2811,18 @@ def dropna(
28102811
raise NotImplementedError(
28112812
f"'inplace'=True not supported. {constants.FEEDBACK_LINK}"
28122813
)
2813-
if how not in ("any", "all"):
2814-
raise ValueError("'how' must be one of 'any', 'all'")
2814+
2815+
# Check if both thresh and how are explicitly provided
2816+
if thresh is not None:
2817+
# cannot specify both thresh and how parameters
2818+
if how != "any":
2819+
raise TypeError(
2820+
"You cannot set both the how and thresh arguments at the same time."
2821+
)
2822+
else:
2823+
# Only validate 'how' when thresh is not provided
2824+
if how not in ("any", "all"):
2825+
raise ValueError("'how' must be one of 'any', 'all'")
28152826

28162827
axis_n = utils.get_axis_number(axis)
28172828

@@ -2833,21 +2844,38 @@ def dropna(
28332844
for id_ in self._block.label_to_col_id[label]
28342845
]
28352846

2836-
result = block_ops.dropna(self._block, self._block.value_columns, how=how, subset=subset_ids) # type: ignore
2847+
result = block_ops.dropna(
2848+
self._block,
2849+
self._block.value_columns,
2850+
how=how,
2851+
thresh=thresh,
2852+
subset=subset_ids,
2853+
) # type: ignore
28372854
if ignore_index:
28382855
result = result.reset_index()
28392856
return DataFrame(result)
28402857
else:
2841-
isnull_block = self._block.multi_apply_unary_op(ops.isnull_op)
2842-
if how == "any":
2843-
null_locations = DataFrame(isnull_block).any().to_pandas()
2844-
else: # 'all'
2845-
null_locations = DataFrame(isnull_block).all().to_pandas()
2846-
keep_columns = [
2847-
col
2848-
for col, to_drop in zip(self._block.value_columns, null_locations)
2849-
if not to_drop
2850-
]
2858+
if thresh is not None:
2859+
# Keep columns with at least 'thresh' non-null values
2860+
notnull_block = self._block.multi_apply_unary_op(ops.notnull_op)
2861+
notnull_counts = DataFrame(notnull_block).sum().to_pandas()
2862+
2863+
keep_columns = [
2864+
col
2865+
for col, count in zip(self._block.value_columns, notnull_counts)
2866+
if count >= thresh
2867+
]
2868+
else:
2869+
isnull_block = self._block.multi_apply_unary_op(ops.isnull_op)
2870+
if how == "any":
2871+
null_locations = DataFrame(isnull_block).any().to_pandas()
2872+
else: # 'all'
2873+
null_locations = DataFrame(isnull_block).all().to_pandas()
2874+
keep_columns = [
2875+
col
2876+
for col, to_drop in zip(self._block.value_columns, null_locations)
2877+
if not to_drop
2878+
]
28512879
return DataFrame(self._block.select_columns(keep_columns))
28522880

28532881
def any(

tests/system/small/test_dataframe.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1207,7 +1207,7 @@ def test_assign_callable_lambda(scalars_dfs):
12071207
(1, "all", False, None),
12081208
],
12091209
)
1210-
def test_df_dropna(scalars_dfs, axis, how, ignore_index, subset):
1210+
def test_df_dropna_by_how(scalars_dfs, axis, how, ignore_index, subset):
12111211
# TODO: supply a reason why this isn't compatible with pandas 1.x
12121212
pytest.importorskip("pandas", minversion="2.0.0")
12131213
scalars_df, scalars_pandas_df = scalars_dfs
@@ -1222,6 +1222,36 @@ def test_df_dropna(scalars_dfs, axis, how, ignore_index, subset):
12221222
pandas.testing.assert_frame_equal(bf_result, pd_result)
12231223

12241224

1225+
@pytest.mark.parametrize(
1226+
("axis", "ignore_index", "subset", "thresh"),
1227+
[
1228+
(0, False, None, 2),
1229+
(0, True, None, 3),
1230+
(1, False, None, 2),
1231+
],
1232+
)
1233+
def test_df_dropna_by_thresh(scalars_dfs, axis, ignore_index, subset, thresh):
1234+
"""
1235+
Tests that dropna correctly keeps rows/columns with a minimum number
1236+
of non-null values.
1237+
"""
1238+
# TODO: supply a reason why this isn't compatible with pandas 1.x
1239+
pytest.importorskip("pandas", minversion="2.0.0")
1240+
scalars_df, scalars_pandas_df = scalars_dfs
1241+
1242+
df_result = scalars_df.dropna(
1243+
axis=axis, thresh=thresh, ignore_index=ignore_index, subset=subset
1244+
)
1245+
pd_result = scalars_pandas_df.dropna(
1246+
axis=axis, thresh=thresh, ignore_index=ignore_index, subset=subset
1247+
)
1248+
1249+
bf_result = df_result.to_pandas()
1250+
# Pandas uses int64 instead of Int64 (nullable) dtype.
1251+
pd_result.index = pd_result.index.astype(pd.Int64Dtype())
1252+
pd.testing.assert_frame_equal(bf_result, pd_result)
1253+
1254+
12251255
def test_df_dropna_range_columns(scalars_dfs):
12261256
# TODO: supply a reason why this isn't compatible with pandas 1.x
12271257
pytest.importorskip("pandas", minversion="2.0.0")

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1762,6 +1762,7 @@ def dropna(
17621762
*,
17631763
axis: int | str = 0,
17641764
how: str = "any",
1765+
thresh: Optional[int] = None,
17651766
subset=None,
17661767
inplace: bool = False,
17671768
ignore_index=False,
@@ -1812,6 +1813,25 @@ def dropna(
18121813
<BLANKLINE>
18131814
[3 rows x 3 columns]
18141815
1816+
Keep rows with at least 2 non-null values.
1817+
1818+
>>> df.dropna(thresh=2)
1819+
name toy born
1820+
1 Batman Batmobile 1940-04-25
1821+
2 Catwoman Bullwhip <NA>
1822+
<BLANKLINE>
1823+
[2 rows x 3 columns]
1824+
1825+
Keep columns with at least 2 non-null values:
1826+
1827+
>>> df.dropna(axis='columns', thresh=2)
1828+
name toy
1829+
0 Alfred <NA>
1830+
1 Batman Batmobile
1831+
2 Catwoman Bullwhip
1832+
<BLANKLINE>
1833+
[3 rows x 2 columns]
1834+
18151835
Define in which columns to look for missing values.
18161836
18171837
>>> df.dropna(subset=['name', 'toy'])
@@ -1822,7 +1842,7 @@ def dropna(
18221842
[2 rows x 3 columns]
18231843
18241844
Args:
1825-
axis ({0 or 'index', 1 or 'columns'}, default 'columns'):
1845+
axis ({0 or 'index', 1 or 'columns'}, default 0):
18261846
Determine if rows or columns which contain missing values are
18271847
removed.
18281848
@@ -1834,6 +1854,8 @@ def dropna(
18341854
18351855
* 'any' : If any NA values are present, drop that row or column.
18361856
* 'all' : If all values are NA, drop that row or column.
1857+
thresh (int, optional):
1858+
Require that many non-NA values. Cannot be combined with how.
18371859
subset (column label or sequence of labels, optional):
18381860
Labels along other axis to consider, e.g. if you are dropping
18391861
rows these would be a list of columns to include.
@@ -1851,6 +1873,8 @@ def dropna(
18511873
Raises:
18521874
ValueError:
18531875
If ``how`` is not one of ``any`` or ``all``.
1876+
TyperError:
1877+
If both ``how`` and ``thresh`` are specified.
18541878
"""
18551879
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
18561880

0 commit comments

Comments
 (0)