Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,6 +2828,19 @@ def itertuples(
for item in df.itertuples(index=index, name=name):
yield item

def _apply_callable(self, condition):
    """Resolve ``condition`` to a concrete value when it is a callable.

    Args:
        condition: Either a callable or any other object.

    Returns:
        * ``condition`` unchanged when it is not callable;
        * ``self.apply(condition, axis=1)`` when the callable carries a
          ``bigframes_bigquery_function`` attribute (a bigframes function);
        * ``condition(self)`` for any other callable (a plain Python
          function).
    """
    # Guard clause: a non-callable needs no execution.
    if not callable(condition):
        return condition

    # A bigframes function is recognised by this marker attribute and is
    # routed through ``apply`` with ``axis=1`` rather than invoked directly.
    if hasattr(condition, "bigframes_bigquery_function"):
        return self.apply(condition, axis=1)

    # A plain Python function is called with this object as its argument.
    return condition(self)

def where(self, cond, other=None):
if isinstance(other, bigframes.series.Series):
raise ValueError("Seires is not a supported replacement type!")
Expand All @@ -2839,16 +2852,8 @@ def where(self, cond, other=None):

# Execute it with the DataFrame when cond or/and other is callable.
# It can be either a plain python function or remote/managed function.
if callable(cond):
if hasattr(cond, "bigframes_bigquery_function"):
cond = self.apply(cond, axis=1)
else:
cond = cond(self)
if callable(other):
if hasattr(other, "bigframes_bigquery_function"):
other = self.apply(other, axis=1)
else:
other = other(self)
cond = self._apply_callable(cond)
other = self._apply_callable(other)

aligned_block, (_, _) = self._block.join(cond._block, how="left")
# No left join is needed when 'other' is None or constant.
Expand Down Expand Up @@ -2899,7 +2904,7 @@ def where(self, cond, other=None):
return result

def mask(self, cond, other=None):
return self.where(~cond, other=other)
return self.where(~self._apply_callable(cond), other=other)

def dropna(
self,
Expand Down
38 changes: 32 additions & 6 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def float_parser(row):
)


def test_managed_function_df_where(session, dataset_id, scalars_dfs):
def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
Expand All @@ -987,15 +987,15 @@ def is_sum_positive(a, b):
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# Use callable condition in dataframe.where method.
# Test callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
# Pandas doesn't support such case, use following as workaround.
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

# Make sure the read_gbq_function path works for this function.
# Make sure the read_gbq_function path works for dataframe.where method.
is_sum_positive_ref = session.read_gbq_function(
function_name=is_sum_positive_mf.bigframes_bigquery_function
)
Expand All @@ -1012,14 +1012,27 @@ def is_sum_positive(a, b):
bf_result_gbq, pd_result_gbq, check_dtype=False
)

# Test callable condition in dataframe.mask method.
bf_result_gbq = bf_int64_df_filtered.mask(
is_sum_positive_ref, -bf_int64_df_filtered
).to_pandas()
pd_result_gbq = pd_int64_df_filtered.mask(
pd_int64_df_filtered.sum(axis=1) > 0, -pd_int64_df_filtered
)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(
bf_result_gbq, pd_result_gbq, check_dtype=False
)

finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(
is_sum_positive_mf, session.bqclient, ignore_failures=False
)


def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
def test_managed_function_df_where_mask_series(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
Expand All @@ -1041,14 +1054,14 @@ def is_sum_positive_series(s):
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# Use callable condition in dataframe.where method.
# Test callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(is_sum_positive_series).to_pandas()
pd_result = pd_int64_df_filtered.where(is_sum_positive_series)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

# Make sure the read_gbq_function path works for this function.
# Make sure the read_gbq_function path works for dataframe.where method.
is_sum_positive_series_ref = session.read_gbq_function(
function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
is_row_processor=True,
Expand All @@ -1070,6 +1083,19 @@ def func_for_other(x):
bf_result_gbq, pd_result_gbq, check_dtype=False
)

# Test callable condition in dataframe.mask method.
bf_result_gbq = bf_int64_df_filtered.mask(
is_sum_positive_series_ref, func_for_other
).to_pandas()
pd_result_gbq = pd_int64_df_filtered.mask(
is_sum_positive_series, func_for_other
)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(
bf_result_gbq, pd_result_gbq, check_dtype=False
)

finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(
Expand Down
25 changes: 21 additions & 4 deletions tests/system/large/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2850,7 +2850,7 @@ def foo(x: int) -> int:


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
def test_remote_function_df_where_mask(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
Expand All @@ -2873,14 +2873,22 @@ def is_sum_positive(a, b):
pd_int64_df = scalars_pandas_df[int64_cols]
pd_int64_df_filtered = pd_int64_df.dropna()

# Use callable condition in dataframe.where method.
# Test callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(is_sum_positive_mf, 0).to_pandas()
# Pandas doesn't support such case, use following as workaround.
pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

# Test callable condition in dataframe.mask method.
bf_result = bf_int64_df_filtered.mask(is_sum_positive_mf, 0).to_pandas()
# Pandas doesn't support such case, use following as workaround.
pd_result = pd_int64_df_filtered.mask(pd_int64_df_filtered.sum(axis=1) > 0, 0)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

finally:
# Clean up the gcp assets created for the remote function.
cleanup_function_assets(
Expand All @@ -2889,7 +2897,7 @@ def is_sum_positive(a, b):


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
def test_remote_function_df_where_mask_series(session, dataset_id, scalars_dfs):
try:

# The return type has to be bool type for callable where condition.
Expand All @@ -2916,7 +2924,7 @@ def is_sum_positive_series(s):
def func_for_other(x):
return -x

# Use callable condition in dataframe.where method.
# Test callable condition in dataframe.where method.
bf_result = bf_int64_df_filtered.where(
is_sum_positive_series, func_for_other
).to_pandas()
Expand All @@ -2925,6 +2933,15 @@ def func_for_other(x):
# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

# Test callable condition in dataframe.mask method.
bf_result = bf_int64_df_filtered.mask(
is_sum_positive_series_mf, func_for_other
).to_pandas()
pd_result = pd_int64_df_filtered.mask(is_sum_positive_series, func_for_other)

# Ignore any dtype difference.
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

finally:
# Clean up the gcp assets created for the remote function.
cleanup_function_assets(
Expand Down
12 changes: 12 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,18 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
pandas.testing.assert_frame_equal(bf_result, pd_result)


def test_mask_callable(scalars_df_index, scalars_pandas_df_index):
    """Check that ``mask`` accepts callables for both ``cond`` and ``other``.

    The same two plain Python callables are handed to the DataFrame under
    test and to the pandas DataFrame, and the results are compared with
    ``pandas.testing.assert_frame_equal``.
    """
    # Single source of truth for the columns selected from both frames.
    columns = ["int64_too", "int64_col", "float64_col"]

    def is_positive(x):
        return x > 0

    def add_one(x):
        return x + 1

    bf_df = scalars_df_index[columns]
    pd_df = scalars_pandas_df_index[columns]

    bf_result = bf_df.mask(cond=is_positive, other=add_one).to_pandas()
    pd_result = pd_df.mask(cond=is_positive, other=add_one)

    pandas.testing.assert_frame_equal(bf_result, pd_result)


def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
# Test when a dataframe has multi-columns.
columns = ["int64_col", "float64_col"]
Expand Down