diff --git a/bigframes/series.py b/bigframes/series.py index 58bd47bff0..80952f38bc 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -2113,13 +2113,8 @@ def duplicated(self, keep: str = "first") -> Series: ) def mask(self, cond, other=None) -> Series: - if callable(cond): - if hasattr(cond, "bigframes_bigquery_function"): - cond = self.apply(cond) - else: - # For non-BigQuery function assume that it is applicable on Series - cond = self.apply(cond, by_row=False) - + cond = self._apply_callable(cond) + other = self._apply_callable(other) if not isinstance(cond, Series): raise TypeError( f"Only bigframes series condition is supported, received {type(cond).__name__}. " diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 262f5f0fe2..43fb322567 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1077,7 +1077,7 @@ def func_for_other(x): ) -def test_managed_function_series_where(session, dataset_id, scalars_dfs): +def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -1098,8 +1098,8 @@ def _is_positive(s): pd_int64 = scalars_pandas["int64_col"] pd_int64_filtered = pd_int64.dropna() - # The cond is a callable (managed function) and the other is not a - # callable in series.where method. + # Test series.where method: the cond is a callable (managed function) + # and the other is not a callable. bf_result = bf_int64_filtered.where( cond=is_positive_mf, other=-bf_int64_filtered ).to_pandas() @@ -1108,6 +1108,16 @@ def _is_positive(s): # Ignore any dtype difference. pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + # Test series.mask method: the cond is a callable (managed function) + # and the other is not a callable. + bf_result = bf_int64_filtered.mask( + cond=is_positive_mf, other=-bf_int64_filtered + ).to_pandas() + pd_result = pd_int64_filtered.mask(cond=_is_positive, other=-pd_int64_filtered) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False) diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index 9e2c1e2c81..1c44b7e5fb 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2933,7 +2933,7 @@ def func_for_other(x): @pytest.mark.flaky(retries=2, delay=120) -def test_remote_function_series_where(session, dataset_id, scalars_dfs): +def test_remote_function_series_where_mask(session, dataset_id, scalars_dfs): try: def _ten_times(x): @@ -2954,8 +2954,8 @@ def _ten_times(x): pd_int64 = scalars_pandas["float64_col"] pd_int64_filtered = pd_int64.dropna() - # The cond is not a callable and the other is a callable (remote - # function) in series.where method. + # Test series.where method: the cond is not a callable and the other is + # a callable (remote function). bf_result = bf_int64_filtered.where( cond=bf_int64_filtered < 0, other=ten_times_mf ).to_pandas() @@ -2966,6 +2966,16 @@ def _ten_times(x): # Ignore any dtype difference. pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + # Test series.mask method: the cond is not a callable and the other is + # a callable (remote function). + bf_result = bf_int64_filtered.mask( + cond=bf_int64_filtered < 0, other=ten_times_mf + ).to_pandas() + pd_result = pd_int64_filtered.mask(cond=pd_int64_filtered < 0, other=_ten_times) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False) diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index 60a3d73dd4..165e3b6df0 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -3603,6 +3603,26 @@ def test_mask_custom_value(scalars_dfs): assert_pandas_df_equal(bf_result, pd_result) +def test_mask_with_callable(scalars_df_index, scalars_pandas_df_index): + def _ten_times(x): + return x * 10 + + # Both cond and other are callable. + bf_result = ( + scalars_df_index["int64_col"] + .mask(cond=lambda x: x > 0, other=_ten_times) + .to_pandas() + ) + pd_result = scalars_pandas_df_index["int64_col"].mask( + cond=lambda x: x > 0, other=_ten_times + ) + + pd.testing.assert_series_equal( + bf_result, + pd_result, + ) + + @pytest.mark.parametrize( ("lambda_",), [