Skip to content

Commit a8d57d2

Browse files
authored
feat: Allow callable as a conditional or replacement input in DataFrame.where (#1971)
* feat: Allow callable as a conditional or replacement input in DataFrame.where() * fix lint
1 parent cd954ac commit a8d57d2

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

bigframes/dataframe.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,6 +2763,12 @@ def where(self, cond, other=None):
27632763
"The dataframe.where() method does not support multi-column."
27642764
)
27652765

2766+
# Execute it with the DataFrame when cond or/and other is callable.
2767+
if callable(cond):
2768+
cond = cond(self)
2769+
if callable(other):
2770+
other = other(self)
2771+
27662772
aligned_block, (_, _) = self._block.join(cond._block, how="left")
27672773
# No left join is needed when 'other' is None or constant.
27682774
if isinstance(other, bigframes.dataframe.DataFrame):

tests/system/small/test_dataframe.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,50 @@ def test_where_dataframe_cond_dataframe_other(
514514
pandas.testing.assert_frame_equal(bf_result, pd_result)
515515

516516

517+
def test_where_callable_cond_constant_other(scalars_df_index, scalars_pandas_df_index):
518+
# Condition is callable, other is a constant.
519+
columns = ["int64_col", "float64_col"]
520+
dataframe_bf = scalars_df_index[columns]
521+
dataframe_pd = scalars_pandas_df_index[columns]
522+
523+
other = 10
524+
525+
bf_result = dataframe_bf.where(lambda x: x > 0, other).to_pandas()
526+
pd_result = dataframe_pd.where(lambda x: x > 0, other)
527+
pandas.testing.assert_frame_equal(bf_result, pd_result)
528+
529+
530+
def test_where_dataframe_cond_callable_other(scalars_df_index, scalars_pandas_df_index):
531+
# Condition is a dataframe, other is callable.
532+
columns = ["int64_col", "float64_col"]
533+
dataframe_bf = scalars_df_index[columns]
534+
dataframe_pd = scalars_pandas_df_index[columns]
535+
536+
cond_bf = dataframe_bf > 0
537+
cond_pd = dataframe_pd > 0
538+
539+
def func(x):
540+
return x * 2
541+
542+
bf_result = dataframe_bf.where(cond_bf, func).to_pandas()
543+
pd_result = dataframe_pd.where(cond_pd, func)
544+
pandas.testing.assert_frame_equal(bf_result, pd_result)
545+
546+
547+
def test_where_callable_cond_callable_other(scalars_df_index, scalars_pandas_df_index):
548+
# Condition is callable, other is callable too.
549+
columns = ["int64_col", "float64_col"]
550+
dataframe_bf = scalars_df_index[columns]
551+
dataframe_pd = scalars_pandas_df_index[columns]
552+
553+
def func(x):
554+
return x["int64_col"] > 0
555+
556+
bf_result = dataframe_bf.where(func, lambda x: x * 2).to_pandas()
557+
pd_result = dataframe_pd.where(func, lambda x: x * 2)
558+
pandas.testing.assert_frame_equal(bf_result, pd_result)
559+
560+
517561
def test_drop_column(scalars_dfs):
518562
scalars_df, scalars_pandas_df = scalars_dfs
519563
col_name = "int64_col"

0 commit comments

Comments
 (0)