diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 921893fb83..85b8245272 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2828,6 +2828,19 @@ def itertuples( for item in df.itertuples(index=index, name=name): yield item + def _apply_callable(self, condition): + """Executes the possible callable condition as needed.""" + if callable(condition): + # When it's a bigframes function. + if hasattr(condition, "bigframes_bigquery_function"): + return self.apply(condition, axis=1) + + # When it's a plain Python function. + return condition(self) + + # When it's not a callable. + return condition + def where(self, cond, other=None): if isinstance(other, bigframes.series.Series): raise ValueError("Seires is not a supported replacement type!") @@ -2839,16 +2852,8 @@ def where(self, cond, other=None): # Execute it with the DataFrame when cond or/and other is callable. # It can be either a plain python function or remote/managed function. - if callable(cond): - if hasattr(cond, "bigframes_bigquery_function"): - cond = self.apply(cond, axis=1) - else: - cond = cond(self) - if callable(other): - if hasattr(other, "bigframes_bigquery_function"): - other = self.apply(other, axis=1) - else: - other = other(self) + cond = self._apply_callable(cond) + other = self._apply_callable(other) aligned_block, (_, _) = self._block.join(cond._block, how="left") # No left join is needed when 'other' is None or constant. @@ -2899,7 +2904,7 @@ def where(self, cond, other=None): return result def mask(self, cond, other=None): - return self.where(~cond, other=other) + return self.where(~self._apply_callable(cond), other=other) def dropna( self, diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 6f5ef5b534..73335afa3c 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -965,7 +965,7 @@ def float_parser(row): ) -def test_managed_function_df_where(session, dataset_id, scalars_dfs): +def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -987,7 +987,7 @@ def is_sum_positive(a, b): pd_int64_df = scalars_pandas_df[int64_cols] pd_int64_df_filtered = pd_int64_df.dropna() - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas() # Pandas doesn't support such case, use following as workaround. pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0) @@ -995,7 +995,7 @@ def is_sum_positive(a, b): # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) - # Make sure the read_gbq_function path works for this function. + # Make sure the read_gbq_function path works for dataframe.where method. is_sum_positive_ref = session.read_gbq_function( function_name=is_sum_positive_mf.bigframes_bigquery_function ) @@ -1012,6 +1012,19 @@ def is_sum_positive(a, b): bf_result_gbq, pd_result_gbq, check_dtype=False ) + # Test callable condition in dataframe.mask method. + bf_result_gbq = bf_int64_df_filtered.mask( + is_sum_positive_ref, -bf_int64_df_filtered + ).to_pandas() + pd_result_gbq = pd_int64_df_filtered.mask( + pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered + ) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal( + bf_result_gbq, pd_result_gbq, check_dtype=False + ) + finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets( @@ -1019,7 +1032,7 @@ def is_sum_positive(a, b): ) -def test_managed_function_df_where_series(session, dataset_id, scalars_dfs): +def test_managed_function_df_where_mask_series(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -1041,14 +1054,14 @@ def is_sum_positive_series(s): pd_int64_df = scalars_pandas_df[int64_cols] pd_int64_df_filtered = pd_int64_df.dropna() - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas() pd_result = pd_int64_df_filtered.where(is_sum_positive_series) # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) - # Make sure the read_gbq_function path works for this function. + # Make sure the read_gbq_function path works for dataframe.where method. is_sum_positive_series_ref = session.read_gbq_function( function_name=is_sum_positive_series_mf.bigframes_bigquery_function, is_row_processor=True, @@ -1070,6 +1083,19 @@ def func_for_other(x): bf_result_gbq, pd_result_gbq, check_dtype=False ) + # Test callable condition in dataframe.mask method. + bf_result_gbq = bf_int64_df_filtered.mask( + is_sum_positive_series_ref, func_for_other + ).to_pandas() + pd_result_gbq = pd_int64_df_filtered.mask( + is_sum_positive_series, func_for_other + ) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal( + bf_result_gbq, pd_result_gbq, check_dtype=False + ) + finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets( diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index cb61d3769c..3c453a52a4 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2850,7 +2850,7 @@ def foo(x: int) -> int: @pytest.mark.flaky(retries=2, delay=120) -def test_remote_function_df_where(session, dataset_id, scalars_dfs): +def test_remote_function_df_where_mask(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -2873,7 +2873,7 @@ def is_sum_positive(a, b): pd_int64_df = scalars_pandas_df[int64_cols] pd_int64_df_filtered = pd_int64_df.dropna() - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas() # Pandas doesn't support such case, use following as workaround. pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0) @@ -2881,6 +2881,14 @@ def is_sum_positive(a, b): # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + # Test callable condition in dataframe.mask method. + bf_result = bf_int64_df_filtered.mask(is_sum_positive_mf, 0).to_pandas() + # Pandas doesn't support such case, use following as workaround. + pd_result = pd_int64_df_filtered.mask(pd_int64_df_filtered.sum(axis=1) > 0, 0) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets( @@ -2889,7 +2897,7 @@ def is_sum_positive(a, b): @pytest.mark.flaky(retries=2, delay=120) -def test_remote_function_df_where_series(session, dataset_id, scalars_dfs): +def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs): try: # The return type has to be bool type for callable where condition. @@ -2916,7 +2924,7 @@ def is_sum_positive_series(s): def func_for_other(x): return -x - # Use callable condition in dataframe.where method. + # Test callable condition in dataframe.where method. bf_result = bf_int64_df_filtered.where( is_sum_positive_series, func_for_other ).to_pandas() @@ -2925,6 +2933,15 @@ def func_for_other(x): # Ignore any dtype difference. pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + # Test callable condition in dataframe.mask method. + bf_result = bf_int64_df_filtered.mask( + is_sum_positive_series_mf, func_for_other + ).to_pandas() + pd_result = pd_int64_df_filtered.mask(is_sum_positive_series, func_for_other) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets( diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index 8a570ade45..51f4674ba4 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -406,6 +406,18 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index): pandas.testing.assert_frame_equal(bf_result, pd_result) +def test_mask_callable(scalars_df_index, scalars_pandas_df_index): + def is_positive(x): + return x > 0 + + bf_df = scalars_df_index[["int64_too", "int64_col", "float64_col"]] + pd_df = scalars_pandas_df_index[["int64_too", "int64_col", "float64_col"]] + bf_result = bf_df.mask(cond=is_positive, other=lambda x: x + 1).to_pandas() + pd_result = pd_df.mask(cond=is_positive, other=lambda x: x + 1) + + pandas.testing.assert_frame_equal(bf_result, pd_result) + + def test_where_multi_column(scalars_df_index, scalars_pandas_df_index): # Test when a dataframe has multi-columns. columns = ["int64_col", "float64_col"]