Skip to content

feat: Support callable bigframes function for dataframe where #1990

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2797,10 +2797,17 @@ def where(self, cond, other=None):
)

# Execute it with the DataFrame when cond or/and other is callable.
# It can be either a plain python function or remote/managed function.
if callable(cond):
cond = cond(self)
if hasattr(cond, "bigframes_bigquery_function"):
cond = self.apply(cond, axis=1)
else:
cond = cond(self)
if callable(other):
other = other(self)
if hasattr(other, "bigframes_bigquery_function"):
other = self.apply(other, axis=1)
else:
other = other(self)

aligned_block, (_, _) = self._block.join(cond._block, how="left")
# No left join is needed when 'other' is None or constant.
Expand All @@ -2813,7 +2820,7 @@ def where(self, cond, other=None):
labels = aligned_block.column_labels[:self_len]
self_col = {x: ex.deref(y) for x, y in zip(labels, ids)}

if isinstance(cond, bigframes.series.Series) and cond.name in self_col:
if isinstance(cond, bigframes.series.Series):
# This is when 'cond' is a valid series.
y = aligned_block.value_columns[self_len]
cond_col = {x: ex.deref(y) for x in self_col.keys()}
Expand Down
112 changes: 112 additions & 0 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,115 @@ def float_parser(row):
cleanup_function_assets(
float_parser_mf, session.bqclient, ignore_failures=False
)


def test_managed_function_df_where(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
def is_sum_positive(a, b):
return a + b > 0

is_sum_positive_mf = session.udf(
input_types=[int, int],
output_type=bool,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(is_sum_positive)

scalars_df, scalars_pandas_df = scalars_dfs
int64_cols = ["int64_col", "int64_too"]

bf_int64_df = scalars_df[int64_cols]
bf_int64_df_filtered = bf_int64_df.dropna()
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# Use callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
# Pandas doesn't support such case, use following as workaround.
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

# Make sure the read_gbq_function path works for this function.
is_sum_positive_ref = session.read_gbq_function(
function_name=is_sum_positive_mf.bigframes_bigquery_function
)

bf_result_gbq = bf_int64_df_filtered.where(
is_sum_positive_ref, -bf_int64_df_filtered
).to_pandas()
pd_result_gbq = pd_int64_df_filtered.where(
pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(
bf_result_gbq, pd_result_gbq, check_dtype=False
)

finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(
is_sum_positive_mf, session.bqclient, ignore_failures=False
)


def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
def is_sum_positive_series(s):
return s["int64_col"] + s["int64_too"] > 0

is_sum_positive_series_mf = session.udf(
input_types=bigframes.series.Series,
output_type=bool,
dataset=dataset_id,
name=prefixer.create_prefix(),
)(is_sum_positive_series)

scalars_df, scalars_pandas_df = scalars_dfs
int64_cols = ["int64_col", "int64_too"]

bf_int64_df = scalars_df[int64_cols]
bf_int64_df_filtered = bf_int64_df.dropna()
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# Use callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
pd_result = pd_int64_df_filtered.where(is_sum_positive_series)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

# Make sure the read_gbq_function path works for this function.
is_sum_positive_series_ref = session.read_gbq_function(
function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
is_row_processor=True,
)

# This is for callable `other` arg in dataframe.where method.
def func_for_other(x):
return -x

bf_result_gbq = bf_int64_df_filtered.where(
is_sum_positive_series_ref, func_for_other
).to_pandas()
pd_result_gbq = pd_int64_df_filtered.where(
is_sum_positive_series, func_for_other
)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(
bf_result_gbq, pd_result_gbq, check_dtype=False
)

finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(
is_sum_positive_series_mf, session.bqclient, ignore_failures=False
)
83 changes: 83 additions & 0 deletions tests/system/large/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2847,3 +2847,86 @@ def foo(x: int) -> int:
finally:
# clean up the gcp assets created for the remote function
cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
def is_sum_positive(a, b):
return a + b > 0

is_sum_positive_mf = session.remote_function(
input_types=[int, int],
output_type=bool,
dataset=dataset_id,
reuse=False,
cloud_function_service_account="default",
)(is_sum_positive)

scalars_df, scalars_pandas_df = scalars_dfs
int64_cols = ["int64_col", "int64_too"]

bf_int64_df = scalars_df[int64_cols]
bf_int64_df_filtered = bf_int64_df.dropna()
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# Use callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
# Pandas doesn't support such case, use following as workaround.
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

finally:
# Clean up the gcp assets created for the remote function.
cleanup_function_assets(
is_sum_positive_mf, session.bqclient, ignore_failures=False
)


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
def is_sum_positive_series(s):
return s["int64_col"] + s["int64_too"] > 0

is_sum_positive_series_mf = session.remote_function(
input_types=bigframes.series.Series,
output_type=bool,
dataset=dataset_id,
reuse=False,
cloud_function_service_account="default",
)(is_sum_positive_series)

scalars_df, scalars_pandas_df = scalars_dfs
int64_cols = ["int64_col", "int64_too"]

bf_int64_df = scalars_df[int64_cols]
bf_int64_df_filtered = bf_int64_df.dropna()
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# This is for callable `other` arg in dataframe.where method.
def func_for_other(x):
return -x

# Use callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(
is_sum_positive_series, func_for_other
).to_pandas()
pd_result = pd_int64_df_filtered.where(is_sum_positive_series, func_for_other)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

finally:
# Clean up the gcp assets created for the remote function.
cleanup_function_assets(
is_sum_positive_series_mf, session.bqclient, ignore_failures=False
)