Skip to content

Commit a441756

Browse files
committed
feat: Support callable for dataframe mask method
1 parent 0c0c3fa commit a441756

File tree

4 files changed

+81
-21
lines changed

4 files changed

+81
-21
lines changed

bigframes/dataframe.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2787,6 +2787,19 @@ def itertuples(
27872787
for item in df.itertuples(index=index, name=name):
27882788
yield item
27892789

2790+
def _apply_callable(self, condition):
2791+
"""Executes the possible callable condition as needed."""
2792+
if callable(condition):
2793+
# When it's a bigframes function.
2794+
if hasattr(condition, "bigframes_bigquery_function"):
2795+
return self.apply(condition, axis=1)
2796+
2797+
# When it's a plain Python function.
2798+
return condition(self)
2799+
2800+
# When it's not a callable.
2801+
return condition
2802+
27902803
def where(self, cond, other=None):
27912804
if isinstance(other, bigframes.series.Series):
27922805
raise ValueError("Seires is not a supported replacement type!")
@@ -2798,16 +2811,8 @@ def where(self, cond, other=None):
27982811

27992812
# Execute it with the DataFrame when cond or/and other is callable.
28002813
# It can be either a plain python function or remote/managed function.
2801-
if callable(cond):
2802-
if hasattr(cond, "bigframes_bigquery_function"):
2803-
cond = self.apply(cond, axis=1)
2804-
else:
2805-
cond = cond(self)
2806-
if callable(other):
2807-
if hasattr(other, "bigframes_bigquery_function"):
2808-
other = self.apply(other, axis=1)
2809-
else:
2810-
other = other(self)
2814+
cond = self._apply_callable(cond)
2815+
other = self._apply_callable(other)
28112816

28122817
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28132818
# No left join is needed when 'other' is None or constant.
@@ -2858,7 +2863,7 @@ def where(self, cond, other=None):
28582863
return result
28592864

28602865
def mask(self, cond, other=None):
2861-
return self.where(~cond, other=other)
2866+
return self.where(~self._apply_callable(cond), other=other)
28622867

28632868
def dropna(
28642869
self,

tests/system/large/functions/test_managed_function.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ def float_parser(row):
965965
)
966966

967967

968-
def test_managed_function_df_where(session, dataset_id, scalars_dfs):
968+
def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
969969
try:
970970

971971
# The return type has to be bool type for callable where condition.
@@ -987,15 +987,15 @@ def is_sum_positive(a, b):
987987
pd_int64_df = scalars_pandas_df[int64_cols]
988988
pd_int64_df_filtered = pd_int64_df.dropna()
989989

990-
# Use callable condition in dataframe.where method.
990+
# Test callable condition in dataframe.where method.
991991
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
992992
# Pandas doesn't support such case, use following as workaround.
993993
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)
994994

995995
# Ignore any dtype difference.
996996
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
997997

998-
# Make sure the read_gbq_function path works for this function.
998+
# Make sure the read_gbq_function path works for dataframe.where method.
999999
is_sum_positive_ref = session.read_gbq_function(
10001000
function_name=is_sum_positive_mf.bigframes_bigquery_function
10011001
)
@@ -1012,14 +1012,27 @@ def is_sum_positive(a, b):
10121012
bf_result_gbq, pd_result_gbq, check_dtype=False
10131013
)
10141014

1015+
# Test callable condition in dataframe.mask method.
1016+
bf_result_gbq = bf_int64_df_filtered.mask(
1017+
is_sum_positive_ref, -bf_int64_df_filtered
1018+
).to_pandas()
1019+
pd_result_gbq = pd_int64_df_filtered.mask(
1020+
pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
1021+
)
1022+
1023+
# Ignore any dtype difference.
1024+
pandas.testing.assert_frame_equal(
1025+
bf_result_gbq, pd_result_gbq, check_dtype=False
1026+
)
1027+
10151028
finally:
10161029
# Clean up the gcp assets created for the managed function.
10171030
cleanup_function_assets(
10181031
is_sum_positive_mf, session.bqclient, ignore_failures=False
10191032
)
10201033

10211034

1022-
def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
1035+
def test_managed_function_df_where_mask_series(session, dataset_id, scalars_dfs):
10231036
try:
10241037

10251038
# The return type has to be bool type for callable where condition.
@@ -1041,14 +1054,14 @@ def is_sum_positive_series(s):
10411054
pd_int64_df = scalars_pandas_df[int64_cols]
10421055
pd_int64_df_filtered = pd_int64_df.dropna()
10431056

1044-
# Use callable condition in dataframe.where method.
1057+
# Test callable condition in dataframe.where method.
10451058
bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
10461059
pd_result = pd_int64_df_filtered.where(is_sum_positive_series)
10471060

10481061
# Ignore any dtype difference.
10491062
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
10501063

1051-
# Make sure the read_gbq_function path works for this function.
1064+
# Make sure the read_gbq_function path works for dataframe.where method.
10521065
is_sum_positive_series_ref = session.read_gbq_function(
10531066
function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
10541067
is_row_processor=True,
@@ -1070,6 +1083,19 @@ def func_for_other(x):
10701083
bf_result_gbq, pd_result_gbq, check_dtype=False
10711084
)
10721085

1086+
# Test callable condition in dataframe.mask method.
1087+
bf_result_gbq = bf_int64_df_filtered.mask(
1088+
is_sum_positive_series_ref, func_for_other
1089+
).to_pandas()
1090+
pd_result_gbq = pd_int64_df_filtered.mask(
1091+
is_sum_positive_series, func_for_other
1092+
)
1093+
1094+
# Ignore any dtype difference.
1095+
pandas.testing.assert_frame_equal(
1096+
bf_result_gbq, pd_result_gbq, check_dtype=False
1097+
)
1098+
10731099
finally:
10741100
# Clean up the gcp assets created for the managed function.
10751101
cleanup_function_assets(

tests/system/large/functions/test_remote_function.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2850,7 +2850,7 @@ def foo(x: int) -> int:
28502850

28512851

28522852
@pytest.mark.flaky(retries=2, delay=120)
2853-
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
2853+
def test_remote_function_df_where_mask(session, dataset_id, scalars_dfs):
28542854
try:
28552855

28562856
# The return type has to be bool type for callable where condition.
@@ -2873,14 +2873,22 @@ def is_sum_positive(a, b):
28732873
pd_int64_df = scalars_pandas_df[int64_cols]
28742874
pd_int64_df_filtered = pd_int64_df.dropna()
28752875

2876-
# Use callable condition in dataframe.where method.
2876+
# Test callable condition in dataframe.where method.
28772877
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
28782878
# Pandas doesn't support such case, use following as workaround.
28792879
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)
28802880

28812881
# Ignore any dtype difference.
28822882
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
28832883

2884+
# Test callable condition in dataframe.mask method.
2885+
bf_result = bf_int64_df_filtered.mask(is_sum_positive_mf, 0).to_pandas()
2886+
# Pandas doesn't support such case, use following as workaround.
2887+
pd_result = pd_int64_df_filtered.mask(pd_int64_df_filtered.sum(axis=1) > 0, 0)
2888+
2889+
# Ignore any dtype difference.
2890+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2891+
28842892
finally:
28852893
# Clean up the gcp assets created for the remote function.
28862894
cleanup_function_assets(
@@ -2889,7 +2897,7 @@ def is_sum_positive(a, b):
28892897

28902898

28912899
@pytest.mark.flaky(retries=2, delay=120)
2892-
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
2900+
def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
28932901
try:
28942902

28952903
# The return type has to be bool type for callable where condition.
@@ -2916,7 +2924,7 @@ def is_sum_positive_series(s):
29162924
def func_for_other(x):
29172925
return -x
29182926

2919-
# Use callable condition in dataframe.where method.
2927+
# Test callable condition in dataframe.where method.
29202928
bf_result = bf_int64_df_filtered.where(
29212929
is_sum_positive_series, func_for_other
29222930
).to_pandas()
@@ -2925,6 +2933,15 @@ def func_for_other(x):
29252933
# Ignore any dtype difference.
29262934
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
29272935

2936+
# Test callable condition in dataframe.mask method.
2937+
bf_result = bf_int64_df_filtered.mask(
2938+
is_sum_positive_series_mf, func_for_other
2939+
).to_pandas()
2940+
pd_result = pd_int64_df_filtered.mask(is_sum_positive_series, func_for_other)
2941+
2942+
# Ignore any dtype difference.
2943+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
2944+
29282945
finally:
29292946
# Clean up the gcp assets created for the remote function.
29302947
cleanup_function_assets(

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
406406
pandas.testing.assert_frame_equal(bf_result, pd_result)
407407

408408

409+
def test_mask_callable(scalars_df_index, scalars_pandas_df_index):
410+
def is_positive(x):
411+
return x > 0
412+
413+
bf_df = scalars_df_index[["int64_too", "int64_col", "float64_col"]]
414+
pd_df = scalars_pandas_df_index[["int64_too", "int64_col", "float64_col"]]
415+
bf_result = bf_df.mask(cond=is_positive, other=lambda x: x + 1).to_pandas()
416+
pd_result = pd_df.mask(cond=is_positive, other=lambda x: x + 1)
417+
418+
pandas.testing.assert_frame_equal(bf_result, pd_result)
419+
420+
409421
def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
410422
# Test when a dataframe has multi-columns.
411423
columns = ["int64_col", "float64_col"]

0 commit comments

Comments
 (0)