Skip to content

Commit 5ac32eb

Browse files
authored
feat: Support callable for series mask method (#2014)
1 parent d442f41 commit 5ac32eb

File tree

4 files changed, with 48 additions and 13 deletions.

bigframes/series.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2113,13 +2113,8 @@ def duplicated(self, keep: str = "first") -> Series:
21132113
)
21142114

21152115
def mask(self, cond, other=None) -> Series:
2116-
if callable(cond):
2117-
if hasattr(cond, "bigframes_bigquery_function"):
2118-
cond = self.apply(cond)
2119-
else:
2120-
# For non-BigQuery function assume that it is applicable on Series
2121-
cond = self.apply(cond, by_row=False)
2122-
2116+
cond = self._apply_callable(cond)
2117+
other = self._apply_callable(other)
21232118
if not isinstance(cond, Series):
21242119
raise TypeError(
21252120
f"Only bigframes series condition is supported, received {type(cond).__name__}. "

tests/system/large/functions/test_managed_function.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1077,7 +1077,7 @@ def func_for_other(x):
10771077
)
10781078

10791079

1080-
def test_managed_function_series_where(session, dataset_id, scalars_dfs):
1080+
def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs):
10811081
try:
10821082

10831083
# The return type has to be bool type for callable where condition.
@@ -1098,8 +1098,8 @@ def _is_positive(s):
10981098
pd_int64 = scalars_pandas["int64_col"]
10991099
pd_int64_filtered = pd_int64.dropna()
11001100

1101-
# The cond is a callable (managed function) and the other is not a
1102-
# callable in series.where method.
1101+
# Test series.where method: the cond is a callable (managed function)
1102+
# and the other is not a callable.
11031103
bf_result = bf_int64_filtered.where(
11041104
cond=is_positive_mf, other=-bf_int64_filtered
11051105
).to_pandas()
@@ -1108,6 +1108,16 @@ def _is_positive(s):
11081108
# Ignore any dtype difference.
11091109
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
11101110

1111+
# Test series.mask method: the cond is a callable (managed function)
1112+
# and the other is not a callable.
1113+
bf_result = bf_int64_filtered.mask(
1114+
cond=is_positive_mf, other=-bf_int64_filtered
1115+
).to_pandas()
1116+
pd_result = pd_int64_filtered.mask(cond=_is_positive, other=-pd_int64_filtered)
1117+
1118+
# Ignore any dtype difference.
1119+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
1120+
11111121
finally:
11121122
# Clean up the gcp assets created for the managed function.
11131123
cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)

tests/system/large/functions/test_remote_function.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2933,7 +2933,7 @@ def func_for_other(x):
29332933

29342934

29352935
@pytest.mark.flaky(retries=2, delay=120)
2936-
def test_remote_function_series_where(session, dataset_id, scalars_dfs):
2936+
def test_remote_function_series_where_mask(session, dataset_id, scalars_dfs):
29372937
try:
29382938

29392939
def _ten_times(x):
@@ -2954,8 +2954,8 @@ def _ten_times(x):
29542954
pd_int64 = scalars_pandas["float64_col"]
29552955
pd_int64_filtered = pd_int64.dropna()
29562956

2957-
# The cond is not a callable and the other is a callable (remote
2958-
# function) in series.where method.
2957+
# Test series.where method: the cond is not a callable and the other is
2958+
# a callable (remote function).
29592959
bf_result = bf_int64_filtered.where(
29602960
cond=bf_int64_filtered < 0, other=ten_times_mf
29612961
).to_pandas()
@@ -2966,6 +2966,16 @@ def _ten_times(x):
29662966
# Ignore any dtype difference.
29672967
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
29682968

2969+
# Test series.mask method: the cond is not a callable and the other is
2970+
# a callable (remote function).
2971+
bf_result = bf_int64_filtered.mask(
2972+
cond=bf_int64_filtered < 0, other=ten_times_mf
2973+
).to_pandas()
2974+
pd_result = pd_int64_filtered.mask(cond=pd_int64_filtered < 0, other=_ten_times)
2975+
2976+
# Ignore any dtype difference.
2977+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
2978+
29692979
finally:
29702980
# Clean up the gcp assets created for the remote function.
29712981
cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False)

tests/system/small/test_series.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3603,6 +3603,26 @@ def test_mask_custom_value(scalars_dfs):
36033603
assert_pandas_df_equal(bf_result, pd_result)
36043604

36053605

3606+
def test_mask_with_callable(scalars_df_index, scalars_pandas_df_index):
3607+
def _ten_times(x):
3608+
return x * 10
3609+
3610+
# Both cond and other are callable.
3611+
bf_result = (
3612+
scalars_df_index["int64_col"]
3613+
.mask(cond=lambda x: x > 0, other=_ten_times)
3614+
.to_pandas()
3615+
)
3616+
pd_result = scalars_pandas_df_index["int64_col"].mask(
3617+
cond=lambda x: x > 0, other=_ten_times
3618+
)
3619+
3620+
pd.testing.assert_series_equal(
3621+
bf_result,
3622+
pd_result,
3623+
)
3624+
3625+
36063626
@pytest.mark.parametrize(
36073627
("lambda_",),
36083628
[

0 commit comments

Comments
 (0)