diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index d618d13aa4..9ab47276fd 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -2877,9 +2877,6 @@ def _apply_callable(self, condition): return condition def where(self, cond, other=None): - if isinstance(other, bigframes.series.Series): - raise ValueError("Seires is not a supported replacement type!") - if self.columns.nlevels > 1: raise NotImplementedError( "The dataframe.where() method does not support multi-column." @@ -2890,6 +2887,9 @@ def where(self, cond, other=None): cond = self._apply_callable(cond) other = self._apply_callable(other) + if isinstance(other, bigframes.series.Series): + raise ValueError("Seires is not a supported replacement type!") + aligned_block, (_, _) = self._block.join(cond._block, how="left") # No left join is needed when 'other' is None or constant. if isinstance(other, bigframes.dataframe.DataFrame): diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index b0e44b648f..0a04480a78 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1214,6 +1214,37 @@ def func_for_other(x): ) +def test_managed_function_df_where_other_issue(session, dataset_id, scalars_df_index): + try: + + def the_sum(s): + return s["int64_col"] + s["int64_too"] + + the_sum_mf = session.udf( + input_types=bigframes.series.Series, + output_type=int, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(the_sum) + + int64_cols = ["int64_col", "int64_too"] + + bf_int64_df = scalars_df_index[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + + with pytest.raises( + ValueError, + match="Seires is not a supported replacement type!", + ): + # The execution of the callable other=the_sum_mf will return a + # Series, which is not a supported replacement type. + bf_int64_df_filtered.where(cond=bf_int64_df_filtered, other=the_sum_mf) + + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs): try: diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index e6372d768b..f60786437f 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -3004,6 +3004,38 @@ def is_sum_positive(a, b): ) +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_df_where_other_issue(session, dataset_id, scalars_df_index): + try: + + def the_sum(a, b): + return a + b + + the_sum_mf = session.remote_function( + input_types=[int, float], + output_type=float, + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + )(the_sum) + + int64_cols = ["int64_col", "float64_col"] + bf_int64_df = scalars_df_index[int64_cols] + bf_int64_df_filtered = bf_int64_df.dropna() + + with pytest.raises( + ValueError, + match="Seires is not a supported replacement type!", + ): + # The execution of the callable other=the_sum_mf will return a + # Series, which is not a supported replacement type. + bf_int64_df_filtered.where(cond=bf_int64_df > 100, other=the_sum_mf) + + finally: + # Clean up the gcp assets created for the remote function. + cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + @pytest.mark.flaky(retries=2, delay=120) def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs): try: diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index c7f9627531..dce0a649f6 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -570,6 +570,18 @@ def func(x): pandas.testing.assert_frame_equal(bf_result, pd_result) +def test_where_series_other(scalars_df_index): + # When other is a series, throw an error. + columns = ["int64_col", "float64_col"] + dataframe_bf = scalars_df_index[columns] + + with pytest.raises( + ValueError, + match="Seires is not a supported replacement type!", + ): + dataframe_bf.where(dataframe_bf > 0, dataframe_bf["int64_col"]) + + def test_drop_column(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs col_name = "int64_col"