Skip to content

Commit f7f6a59

Browse files
committed
feat: Support callable for series mask method
1 parent 26df6e6 commit f7f6a59

File tree

4 files changed

+48
-13
lines changed

4 files changed

+48
-13
lines changed

bigframes/series.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2107,13 +2107,8 @@ def duplicated(self, keep: str = "first") -> Series:
21072107
)
21082108

21092109
def mask(self, cond, other=None) -> Series:
2110-
if callable(cond):
2111-
if hasattr(cond, "bigframes_bigquery_function"):
2112-
cond = self.apply(cond)
2113-
else:
2114-
# For non-BigQuery function assume that it is applicable on Series
2115-
cond = self.apply(cond, by_row=False)
2116-
2110+
cond = self._apply_callable(cond)
2111+
other = self._apply_callable(other)
21172112
if not isinstance(cond, Series):
21182113
raise TypeError(
21192114
f"Only bigframes series condition is supported, received {type(cond).__name__}. "

tests/system/large/functions/test_managed_function.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ def func_for_other(x):
10771077
)
10781078

10791079

1080-
def test_managed_function_series_where(session, dataset_id, scalars_dfs):
1080+
def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs):
10811081
try:
10821082

10831083
# The return type has to be bool type for callable where condition.
@@ -1098,8 +1098,8 @@ def _is_positive(s):
10981098
pd_int64 = scalars_pandas["int64_col"]
10991099
pd_int64_filtered = pd_int64.dropna()
11001100

1101-
# The cond is a callable (managed function) and the other is not a
1102-
# callable in series.where method.
1101+
# Test series.where method: the cond is a callable (managed function)
1102+
# and the other is not a callable.
11031103
bf_result = bf_int64_filtered.where(
11041104
cond=is_positive_mf, other=-bf_int64_filtered
11051105
).to_pandas()
@@ -1108,6 +1108,16 @@ def _is_positive(s):
11081108
# Ignore any dtype difference.
11091109
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
11101110

1111+
# Test series.mask method: the cond is a callable (managed function)
1112+
# and the other is not a callable.
1113+
bf_result = bf_int64_filtered.mask(
1114+
cond=is_positive_mf, other=-bf_int64_filtered
1115+
).to_pandas()
1116+
pd_result = pd_int64_filtered.mask(cond=_is_positive, other=-pd_int64_filtered)
1117+
1118+
# Ignore any dtype difference.
1119+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
1120+
11111121
finally:
11121122
# Clean up the gcp assets created for the managed function.
11131123
cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)

tests/system/large/functions/test_remote_function.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2933,7 +2933,7 @@ def func_for_other(x):
29332933

29342934

29352935
@pytest.mark.flaky(retries=2, delay=120)
2936-
def test_remote_function_series_where(session, dataset_id, scalars_dfs):
2936+
def test_remote_function_series_where_mask(session, dataset_id, scalars_dfs):
29372937
try:
29382938

29392939
def _ten_times(x):
@@ -2954,8 +2954,8 @@ def _ten_times(x):
29542954
pd_int64 = scalars_pandas["float64_col"]
29552955
pd_int64_filtered = pd_int64.dropna()
29562956

2957-
# The cond is not a callable and the other is a callable (remote
2958-
# function) in series.where method.
2957+
# Test series.where method: the cond is not a callable and the other is
2958+
# a callable (remote function).
29592959
bf_result = bf_int64_filtered.where(
29602960
cond=bf_int64_filtered < 0, other=ten_times_mf
29612961
).to_pandas()
@@ -2966,6 +2966,16 @@ def _ten_times(x):
29662966
# Ignore any dtype difference.
29672967
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
29682968

2969+
# Test series.mask method: the cond is not a callable and the other is
2970+
# a callable (remote function).
2971+
bf_result = bf_int64_filtered.mask(
2972+
cond=bf_int64_filtered < 0, other=ten_times_mf
2973+
).to_pandas()
2974+
pd_result = pd_int64_filtered.mask(cond=pd_int64_filtered < 0, other=_ten_times)
2975+
2976+
# Ignore any dtype difference.
2977+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
2978+
29692979
finally:
29702980
# Clean up the gcp assets created for the remote function.
29712981
cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False)

tests/system/small/test_series.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3577,6 +3577,26 @@ def test_mask_custom_value(scalars_dfs):
35773577
assert_pandas_df_equal(bf_result, pd_result)
35783578

35793579

3580+
def test_mask_with_callable(scalars_df_index, scalars_pandas_df_index):
3581+
def _ten_times(x):
3582+
return x * 10
3583+
3584+
# Both cond and other are callable.
3585+
bf_result = (
3586+
scalars_df_index["int64_col"]
3587+
.mask(cond=lambda x: x > 0, other=_ten_times)
3588+
.to_pandas()
3589+
)
3590+
pd_result = scalars_pandas_df_index["int64_col"].mask(
3591+
cond=lambda x: x > 0, other=_ten_times
3592+
)
3593+
3594+
pd.testing.assert_series_equal(
3595+
bf_result,
3596+
pd_result,
3597+
)
3598+
3599+
35803600
@pytest.mark.parametrize(
35813601
("lambda_",),
35823602
[

0 commit comments

Comments
 (0)