Skip to content

Commit 9d4504b

Browse files
authored
feat: Support callable for dataframe mask method (#2020)
1 parent 8f2cad2 commit 9d4504b

File tree

4 files changed

+81
-21
lines changed

4 files changed

+81
-21
lines changed

bigframes/dataframe.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,6 +2828,19 @@ def itertuples(
28282828
for item in df.itertuples(index=index, name=name):
28292829
yield item
28302830

2831+
def _apply_callable(self, condition):
    """Resolve ``condition`` against this DataFrame if it is callable.

    Args:
        condition: Either a callable or an already-evaluated value. A
            callable may be a plain Python function or a BigFrames
            (remote/managed) function, which is recognised by its
            ``bigframes_bigquery_function`` attribute.

    Returns:
        ``condition`` unchanged when it is not callable; the result of
        ``self.apply(condition, axis=1)`` for a BigFrames function;
        otherwise the result of ``condition(self)``.
    """
    # Non-callables are already concrete values; pass them through as-is.
    if not callable(condition):
        return condition

    # A BigFrames function cannot be invoked on the frame directly, so it
    # is routed through ``apply`` with ``axis=1``.
    if hasattr(condition, "bigframes_bigquery_function"):
        return self.apply(condition, axis=1)

    # A plain Python function receives the whole DataFrame.
    return condition(self)
2843+
28312844
def where(self, cond, other=None):
28322845
if isinstance(other, bigframes.series.Series):
28332846
raise ValueError("Seires is not a supported replacement type!")
@@ -2839,16 +2852,8 @@ def where(self, cond, other=None):
28392852

28402853
# Execute it with the DataFrame when cond or/and other is callable.
28412854
# It can be either a plain python function or remote/managed function.
2842-
if callable(cond):
2843-
if hasattr(cond, "bigframes_bigquery_function"):
2844-
cond = self.apply(cond, axis=1)
2845-
else:
2846-
cond = cond(self)
2847-
if callable(other):
2848-
if hasattr(other, "bigframes_bigquery_function"):
2849-
other = self.apply(other, axis=1)
2850-
else:
2851-
other = other(self)
2855+
cond = self._apply_callable(cond)
2856+
other = self._apply_callable(other)
28522857

28532858
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28542859
# No left join is needed when 'other' is None or constant.
@@ -2899,7 +2904,7 @@ def where(self, cond, other=None):
28992904
return result
29002905

29012906
def mask(self, cond, other=None):
    """Replace values where ``cond`` holds; the inverse of ``where``.

    Args:
        cond: The condition, or a callable producing it. A callable is
            resolved through ``self._apply_callable`` first.
        other: Replacement passed straight through to ``where``.

    Returns:
        The result of ``self.where`` called with the inverted condition.
    """
    # Resolve a callable before negating: ``~`` is not defined for
    # function objects, so ``~cond`` would raise for a callable ``cond``.
    resolved_cond = self._apply_callable(cond)
    return self.where(~resolved_cond, other=other)
29032908

29042909
def dropna(
29052910
self,

tests/system/large/functions/test_managed_function.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def float_parser(row):
965965
)
966966

967967

968-
def test_managed_function_df_where(session, dataset_id, scalars_dfs):
968+
def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
969969
try:
970970

971971
# The return type has to be bool type for callable where condition.
@@ -987,15 +987,15 @@ def is_sum_positive(a, b):
987987
pd_int64_df = scalars_pandas_df[int64_cols]
988988
pd_int64_df_filtered = pd_int64_df.dropna()
989989

990-
# Use callable condition in dataframe.where method.
990+
# Test callable condition in dataframe.where method.
991991
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
992992
# Pandas doesn't support such case, use following as workaround.
993993
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)
994994

995995
# Ignore any dtype difference.
996996
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
997997

998-
# Make sure the read_gbq_function path works for this function.
998+
# Make sure the read_gbq_function path works for dataframe.where method.
999999
is_sum_positive_ref = session.read_gbq_function(
10001000
function_name=is_sum_positive_mf.bigframes_bigquery_function
10011001
)
@@ -1012,14 +1012,27 @@ def is_sum_positive(a, b):
10121012
bf_result_gbq, pd_result_gbq, check_dtype=False
10131013
)
10141014

1015+
# Test callable condition in dataframe.mask method.
1016+
bf_result_gbq = bf_int64_df_filtered.mask(
1017+
is_sum_positive_ref, -bf_int64_df_filtered
1018+
).to_pandas()
1019+
pd_result_gbq = pd_int64_df_filtered.mask(
1020+
pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
1021+
)
1022+
1023+
# Ignore any dtype difference.
1024+
pandas.testing.assert_frame_equal(
1025+
bf_result_gbq, pd_result_gbq, check_dtype=False
1026+
)
1027+
10151028
finally:
10161029
# Clean up the gcp assets created for the managed function.
10171030
cleanup_function_assets(
10181031
is_sum_positive_mf, session.bqclient, ignore_failures=False
10191032
)
10201033

10211034

1022-
def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
1035+
def test_managed_function_df_where_mask_series(session, dataset_id, scalars_dfs):
10231036
try:
10241037

10251038
# The return type has to be bool type for callable where condition.
@@ -1041,14 +1054,14 @@ def is_sum_positive_series(s):
10411054
pd_int64_df = scalars_pandas_df[int64_cols]
10421055
pd_int64_df_filtered = pd_int64_df.dropna()
10431056

1044-
# Use callable condition in dataframe.where method.
1057+
# Test callable condition in dataframe.where method.
10451058
bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
10461059
pd_result = pd_int64_df_filtered.where(is_sum_positive_series)
10471060

10481061
# Ignore any dtype difference.
10491062
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
10501063

1051-
# Make sure the read_gbq_function path works for this function.
1064+
# Make sure the read_gbq_function path works for dataframe.where method.
10521065
is_sum_positive_series_ref = session.read_gbq_function(
10531066
function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
10541067
is_row_processor=True,
@@ -1070,6 +1083,19 @@ def func_for_other(x):
10701083
bf_result_gbq, pd_result_gbq, check_dtype=False
10711084
)
10721085

1086+
# Test callable condition in dataframe.mask method.
1087+
bf_result_gbq = bf_int64_df_filtered.mask(
1088+
is_sum_positive_series_ref, func_for_other
1089+
).to_pandas()
1090+
pd_result_gbq = pd_int64_df_filtered.mask(
1091+
is_sum_positive_series, func_for_other
1092+
)
1093+
1094+
# Ignore any dtype difference.
1095+
pandas.testing.assert_frame_equal(
1096+
bf_result_gbq, pd_result_gbq, check_dtype=False
1097+
)
1098+
10731099
finally:
10741100
# Clean up the gcp assets created for the managed function.
10751101
cleanup_function_assets(

tests/system/large/functions/test_remote_function.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,7 +2850,7 @@ def foo(x: int) -> int:
28502850

28512851

28522852
@pytest.mark.flaky(retries=2, delay=120)
2853-
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
2853+
def test_remote_function_df_where_mask(session, dataset_id, scalars_dfs):
28542854
try:
28552855

28562856
# The return type has to be bool type for callable where condition.
@@ -2873,14 +2873,22 @@ def is_sum_positive(a, b):
28732873
pd_int64_df = scalars_pandas_df[int64_cols]
28742874
pd_int64_df_filtered = pd_int64_df.dropna()
28752875

2876-
# Use callable condition in dataframe.where method.
2876+
# Test callable condition in dataframe.where method.
28772877
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
28782878
# Pandas doesn't support such case, use following as workaround.
28792879
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)
28802880

28812881
# Ignore any dtype difference.
28822882
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
28832883

2884+
# Test callable condition in dataframe.mask method.
2885+
bf_result = bf_int64_df_filtered.mask(is_sum_positive_mf, 0).to_pandas()
2886+
# Pandas doesn't support such case, use following as workaround.
2887+
pd_result = pd_int64_df_filtered.mask(pd_int64_df_filtered.sum(axis=1) > 0, 0)
2888+
2889+
# Ignore any dtype difference.
2890+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2891+
28842892
finally:
28852893
# Clean up the gcp assets created for the remote function.
28862894
cleanup_function_assets(
@@ -2889,7 +2897,7 @@ def is_sum_positive(a, b):
28892897

28902898

28912899
@pytest.mark.flaky(retries=2, delay=120)
2892-
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
2900+
def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
28932901
try:
28942902

28952903
# The return type has to be bool type for callable where condition.
@@ -2916,7 +2924,7 @@ def is_sum_positive_series(s):
29162924
def func_for_other(x):
29172925
return -x
29182926

2919-
# Use callable condition in dataframe.where method.
2927+
# Test callable condition in dataframe.where method.
29202928
bf_result = bf_int64_df_filtered.where(
29212929
is_sum_positive_series, func_for_other
29222930
).to_pandas()
@@ -2925,6 +2933,15 @@ def func_for_other(x):
29252933
# Ignore any dtype difference.
29262934
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
29272935

2936+
# Test callable condition in dataframe.mask method.
2937+
bf_result = bf_int64_df_filtered.mask(
2938+
is_sum_positive_series_mf, func_for_other
2939+
).to_pandas()
2940+
pd_result = pd_int64_df_filtered.mask(is_sum_positive_series, func_for_other)
2941+
2942+
# Ignore any dtype difference.
2943+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2944+
29282945
finally:
29292946
# Clean up the gcp assets created for the remote function.
29302947
cleanup_function_assets(

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
406406
pandas.testing.assert_frame_equal(bf_result, pd_result)
407407

408408

409+
def test_mask_callable(scalars_df_index, scalars_pandas_df_index):
    """Check that ``mask`` with callable ``cond`` and ``other`` matches pandas."""

    def is_positive(x):
        return x > 0

    def add_one(x):
        return x + 1

    # The same column subset is selected from both frames.
    columns = ["int64_too", "int64_col", "float64_col"]
    bf_df = scalars_df_index[columns]
    pd_df = scalars_pandas_df_index[columns]

    bf_result = bf_df.mask(cond=is_positive, other=add_one).to_pandas()
    pd_result = pd_df.mask(cond=is_positive, other=add_one)

    pandas.testing.assert_frame_equal(bf_result, pd_result)
419+
420+
409421
def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
410422
# Test when a dataframe has multi-columns.
411423
columns = ["int64_col", "float64_col"]

0 commit comments

Comments
 (0)