Skip to content

Commit 8689199

Browse files
authored
fix: Resolve the validation issue for other arg in dataframe where method (#2042)
1 parent 1a0f710 commit 8689199

File tree

4 files changed, with 78 additions and 3 deletions.

bigframes/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2877,9 +2877,6 @@ def _apply_callable(self, condition):
28772877
return condition
28782878

28792879
def where(self, cond, other=None):
2880-
if isinstance(other, bigframes.series.Series):
2881-
raise ValueError("Seires is not a supported replacement type!")
2882-
28832880
if self.columns.nlevels > 1:
28842881
raise NotImplementedError(
28852882
"The dataframe.where() method does not support multi-column."
@@ -2890,6 +2887,9 @@ def where(self, cond, other=None):
28902887
cond = self._apply_callable(cond)
28912888
other = self._apply_callable(other)
28922889

2890+
if isinstance(other, bigframes.series.Series):
2891+
raise ValueError("Seires is not a supported replacement type!")
2892+
28932893
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28942894
# No left join is needed when 'other' is None or constant.
28952895
if isinstance(other, bigframes.dataframe.DataFrame):

tests/system/large/functions/test_managed_function.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,37 @@ def func_for_other(x):
12141214
)
12151215

12161216

1217+
def test_managed_function_df_where_other_issue(session, dataset_id, scalars_df_index):
1218+
try:
1219+
1220+
def the_sum(s):
1221+
return s["int64_col"] + s["int64_too"]
1222+
1223+
the_sum_mf = session.udf(
1224+
input_types=bigframes.series.Series,
1225+
output_type=int,
1226+
dataset=dataset_id,
1227+
name=prefixer.create_prefix(),
1228+
)(the_sum)
1229+
1230+
int64_cols = ["int64_col", "int64_too"]
1231+
1232+
bf_int64_df = scalars_df_index[int64_cols]
1233+
bf_int64_df_filtered = bf_int64_df.dropna()
1234+
1235+
with pytest.raises(
1236+
ValueError,
1237+
match="Seires is not a supported replacement type!",
1238+
):
1239+
# The execution of the callable other=the_sum_mf will return a
1240+
# Series, which is not a supported replacement type.
1241+
bf_int64_df_filtered.where(cond=bf_int64_df_filtered, other=the_sum_mf)
1242+
1243+
finally:
1244+
# Clean up the gcp assets created for the managed function.
1245+
cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
1246+
1247+
12171248
def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs):
12181249
try:
12191250

tests/system/large/functions/test_remote_function.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,6 +3004,38 @@ def is_sum_positive(a, b):
30043004
)
30053005

30063006

3007+
@pytest.mark.flaky(retries=2, delay=120)
3008+
def test_remote_function_df_where_other_issue(session, dataset_id, scalars_df_index):
3009+
try:
3010+
3011+
def the_sum(a, b):
3012+
return a + b
3013+
3014+
the_sum_mf = session.remote_function(
3015+
input_types=[int, float],
3016+
output_type=float,
3017+
dataset=dataset_id,
3018+
reuse=False,
3019+
cloud_function_service_account="default",
3020+
)(the_sum)
3021+
3022+
int64_cols = ["int64_col", "float64_col"]
3023+
bf_int64_df = scalars_df_index[int64_cols]
3024+
bf_int64_df_filtered = bf_int64_df.dropna()
3025+
3026+
with pytest.raises(
3027+
ValueError,
3028+
match="Seires is not a supported replacement type!",
3029+
):
3030+
# The execution of the callable other=the_sum_mf will return a
3031+
# Series, which is not a supported replacement type.
3032+
bf_int64_df_filtered.where(cond=bf_int64_df > 100, other=the_sum_mf)
3033+
3034+
finally:
3035+
# Clean up the gcp assets created for the remote function.
3036+
cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
3037+
3038+
30073039
@pytest.mark.flaky(retries=2, delay=120)
30083040
def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
30093041
try:

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,18 @@ def func(x):
570570
pandas.testing.assert_frame_equal(bf_result, pd_result)
571571

572572

573+
def test_where_series_other(scalars_df_index):
574+
# When other is a series, throw an error.
575+
columns = ["int64_col", "float64_col"]
576+
dataframe_bf = scalars_df_index[columns]
577+
578+
with pytest.raises(
579+
ValueError,
580+
match="Seires is not a supported replacement type!",
581+
):
582+
dataframe_bf.where(dataframe_bf > 0, dataframe_bf["int64_col"])
583+
584+
573585
def test_drop_column(scalars_dfs):
574586
scalars_df, scalars_pandas_df = scalars_dfs
575587
col_name = "int64_col"

0 commit comments

Comments
 (0)