diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index bcad00830d..c58cbaba6a 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2797,10 +2797,17 @@ def where(self, cond, other=None): ) # Execute it with the DataFrame when cond or/and other is callable. + # It can be either a plain python function or remote/managed function. if callable(cond): - cond = cond(self) + if hasattr(cond, "bigframes_bigquery_function"): + cond = self.apply(cond, axis=1) + else: + cond = cond(self) if callable(other): - other = other(self) + if hasattr(other, "bigframes_bigquery_function"): + other = self.apply(other, axis=1) + else: + other = other(self) aligned_block, (_, _) = self._block.join(cond._block, how="left") # No left join is needed when 'other' is None or constant. @@ -2813,7 +2820,7 @@ def where(self, cond, other=None): labels = aligned_block.column_labels[:self_len] self_col = {x: ex.deref(y) for x, y in zip(labels, ids)} - if isinstance(cond, bigframes.series.Series) and cond.name in self_col: + if isinstance(cond, bigframes.series.Series): # This is when 'cond' is a valid series. y = aligned_block.value_columns[self_len] cond_col = {x: ex.deref(y) for x in self_col.keys()} diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 5349529f1d..209e4df1e3 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -963,3 +963,115 @@ def float_parser(row): cleanup_function_assets( float_parser_mf, session.bqclient, ignore_failures=False ) + + +def test_managed_function_df_where(session, dataset_id, scalars_dfs): + try: + + # The return type has to be bool type for callable where condition. + def is_sum_positive(a, b): + return a + b > 0 + + is_sum_positive_mf = session.udf( + input_types=[int, int], + output_type=bool, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(is_sum_positive) + + scalars_df, scalars_pandas_df = scalars_dfs + int64_cols = ["int64_col", "int64_too"] + + bf_int64_df = scalars_df[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + pd_int64_df = scalars_pandas_df[int64_cols] + pd_int64_df_filtered = pd_int64_df.dropna() + + # Use callable condition in dataframe.where method. + bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas() + # Pandas doesn't support such case, use following as workaround. + pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + # Make sure the read_gbq_function path works for this function. + is_sum_positive_ref = session.read_gbq_function( + function_name=is_sum_positive_mf.bigframes_bigquery_function + ) + + bf_result_gbq = bf_int64_df_filtered.where( + is_sum_positive_ref, -bf_int64_df_filtered + ).to_pandas() + pd_result_gbq = pd_int64_df_filtered.where( + pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered + ) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal( + bf_result_gbq, pd_result_gbq, check_dtype=False + ) + + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets( + is_sum_positive_mf, session.bqclient, ignore_failures=False + ) + + +def test_managed_function_df_where_series(session, dataset_id, scalars_dfs): + try: + + # The return type has to be bool type for callable where condition. + def is_sum_positive_series(s): + return s["int64_col"] + s["int64_too"] > 0 + + is_sum_positive_series_mf = session.udf( + input_types=bigframes.series.Series, + output_type=bool, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(is_sum_positive_series) + + scalars_df, scalars_pandas_df = scalars_dfs + int64_cols = ["int64_col", "int64_too"] + + bf_int64_df = scalars_df[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + pd_int64_df = scalars_pandas_df[int64_cols] + pd_int64_df_filtered = pd_int64_df.dropna() + + # Use callable condition in dataframe.where method. + bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas() + pd_result = pd_int64_df_filtered.where(is_sum_positive_series) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + # Make sure the read_gbq_function path works for this function. + is_sum_positive_series_ref = session.read_gbq_function( + function_name=is_sum_positive_series_mf.bigframes_bigquery_function, + is_row_processor=True, + ) + + # This is for callable `other` arg in dataframe.where method. + def func_for_other(x): + return -x + + bf_result_gbq = bf_int64_df_filtered.where( + is_sum_positive_series_ref, func_for_other + ).to_pandas() + pd_result_gbq = pd_int64_df_filtered.where( + is_sum_positive_series, func_for_other + ) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal( + bf_result_gbq, pd_result_gbq, check_dtype=False + ) + + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets( + is_sum_positive_series_mf, session.bqclient, ignore_failures=False + ) diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index a93435d11a..558b292c49 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2847,3 +2847,86 @@ def foo(x: int) -> int: finally: # clean up the gcp assets created for the remote function cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_df_where(session, dataset_id, scalars_dfs): + try: + + # The return type has to be bool type for callable where condition. + def is_sum_positive(a, b): + return a + b > 0 + + is_sum_positive_mf = session.remote_function( + input_types=[int, int], + output_type=bool, + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + )(is_sum_positive) + + scalars_df, scalars_pandas_df = scalars_dfs + int64_cols = ["int64_col", "int64_too"] + + bf_int64_df = scalars_df[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + pd_int64_df = scalars_pandas_df[int64_cols] + pd_int64_df_filtered = pd_int64_df.dropna() + + # Use callable condition in dataframe.where method. + bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas() + # Pandas doesn't support such case, use following as workaround. + pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. + cleanup_function_assets( + is_sum_positive_mf, session.bqclient, ignore_failures=False + ) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_df_where_series(session, dataset_id, scalars_dfs): + try: + + # The return type has to be bool type for callable where condition. + def is_sum_positive_series(s): + return s["int64_col"] + s["int64_too"] > 0 + + is_sum_positive_series_mf = session.remote_function( + input_types=bigframes.series.Series, + output_type=bool, + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + )(is_sum_positive_series) + + scalars_df, scalars_pandas_df = scalars_dfs + int64_cols = ["int64_col", "int64_too"] + + bf_int64_df = scalars_df[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + pd_int64_df = scalars_pandas_df[int64_cols] + pd_int64_df_filtered = pd_int64_df.dropna() + + # This is for callable `other` arg in dataframe.where method. + def func_for_other(x): + return -x + + # Use callable condition in dataframe.where method. + bf_result = bf_int64_df_filtered.where( + is_sum_positive_series, func_for_other + ).to_pandas() + pd_result = pd_int64_df_filtered.where(is_sum_positive_series, func_for_other) + + # Ignore any dtype difference. + pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. + cleanup_function_assets( + is_sum_positive_series_mf, session.bqclient, ignore_failures=False + )