Skip to content

Commit 44c1ec4

Browse files
authored
feat: Support callable bigframes function for dataframe where (#1990)
1 parent b692713 commit 44c1ec4

File tree

3 files changed

+205
-3
lines changed

3 files changed

+205
-3
lines changed

bigframes/dataframe.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2797,10 +2797,17 @@ def where(self, cond, other=None):
27972797
)
27982798

27992799
# Execute it with the DataFrame when cond or/and other is callable.
2800+
# It can be either a plain python function or remote/managed function.
28002801
if callable(cond):
2801-
cond = cond(self)
2802+
if hasattr(cond, "bigframes_bigquery_function"):
2803+
cond = self.apply(cond, axis=1)
2804+
else:
2805+
cond = cond(self)
28022806
if callable(other):
2803-
other = other(self)
2807+
if hasattr(other, "bigframes_bigquery_function"):
2808+
other = self.apply(other, axis=1)
2809+
else:
2810+
other = other(self)
28042811

28052812
aligned_block, (_, _) = self._block.join(cond._block, how="left")
28062813
# No left join is needed when 'other' is None or constant.
@@ -2813,7 +2820,7 @@ def where(self, cond, other=None):
28132820
labels = aligned_block.column_labels[:self_len]
28142821
self_col = {x: ex.deref(y) for x, y in zip(labels, ids)}
28152822

2816-
if isinstance(cond, bigframes.series.Series) and cond.name in self_col:
2823+
if isinstance(cond, bigframes.series.Series):
28172824
# This is when 'cond' is a valid series.
28182825
y = aligned_block.value_columns[self_len]
28192826
cond_col = {x: ex.deref(y) for x in self_col.keys()}

tests/system/large/functions/test_managed_function.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,3 +963,115 @@ def float_parser(row):
963963
cleanup_function_assets(
964964
float_parser_mf, session.bqclient, ignore_failures=False
965965
)
966+
967+
968+
def test_managed_function_df_where(session, dataset_id, scalars_dfs):
    """Check ``DataFrame.where`` with a managed function as the condition.

    Covers both the function object returned by ``session.udf`` and the same
    routine re-loaded through ``session.read_gbq_function``; each result is
    compared against an equivalent pandas computation.
    """

    # The return type has to be bool type for callable where condition.
    def is_sum_positive(a, b):
        return a + b > 0

    # Create the managed function before entering ``try``: if creation raises,
    # there is nothing to clean up, and the ``finally`` clause must not hide
    # the real error behind an unbound-name failure.
    is_sum_positive_mf = session.udf(
        input_types=[int, int],
        output_type=bool,
        dataset=dataset_id,
        name=prefixer.create_prefix(),
    )(is_sum_positive)

    try:
        scalars_df, scalars_pandas_df = scalars_dfs
        int64_cols = ["int64_col", "int64_too"]

        # Drop nulls so the two-int function receives valid inputs on every row.
        bf_int64_df_filtered = scalars_df[int64_cols].dropna()
        pd_int64_df_filtered = scalars_pandas_df[int64_cols].dropna()

        # Use callable condition in dataframe.where method.
        bf_result = bf_int64_df_filtered.where(is_sum_positive_mf).to_pandas()
        # pandas does not support this case (a scalar function applied per
        # row), so build the equivalent boolean mask explicitly.
        pd_mask = pd_int64_df_filtered.sum(axis=1) > 0
        pd_result = pd_int64_df_filtered.where(pd_mask)

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

        # Make sure the read_gbq_function path works for this function.
        is_sum_positive_ref = session.read_gbq_function(
            function_name=is_sum_positive_mf.bigframes_bigquery_function
        )

        bf_result_gbq = bf_int64_df_filtered.where(
            is_sum_positive_ref, -bf_int64_df_filtered
        ).to_pandas()
        pd_result_gbq = pd_int64_df_filtered.where(pd_mask, -pd_int64_df_filtered)

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(
            bf_result_gbq, pd_result_gbq, check_dtype=False
        )

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(
            is_sum_positive_mf, session.bqclient, ignore_failures=False
        )
1020+
1021+
1022+
def test_managed_function_df_where_series(session, dataset_id, scalars_dfs):
    """Check ``DataFrame.where`` with a row-processing managed function.

    The condition takes a whole row (a Series) and returns a bool. The test
    exercises the plain Python callable, the managed function created from it,
    and the same routine re-loaded through ``session.read_gbq_function``
    (combined with a callable ``other``), comparing each against pandas.
    """

    # The return type has to be bool type for callable where condition.
    def is_sum_positive_series(s):
        return s["int64_col"] + s["int64_too"] > 0

    # Create the managed function before entering ``try``: if creation raises,
    # there is nothing to clean up, and the ``finally`` clause must not hide
    # the real error behind an unbound-name failure.
    is_sum_positive_series_mf = session.udf(
        input_types=bigframes.series.Series,
        output_type=bool,
        dataset=dataset_id,
        name=prefixer.create_prefix(),
    )(is_sum_positive_series)

    try:
        scalars_df, scalars_pandas_df = scalars_dfs
        int64_cols = ["int64_col", "int64_too"]

        bf_int64_df_filtered = scalars_df[int64_cols].dropna()
        pd_int64_df_filtered = scalars_pandas_df[int64_cols].dropna()

        pd_result = pd_int64_df_filtered.where(is_sum_positive_series)

        # Plain python callable as the condition in dataframe.where method.
        bf_result_plain = bf_int64_df_filtered.where(
            is_sum_positive_series
        ).to_pandas()

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(
            bf_result_plain, pd_result, check_dtype=False
        )

        # Managed function as the condition. Without this call the function
        # created above would only be exercised through read_gbq_function.
        bf_result = bf_int64_df_filtered.where(is_sum_positive_series_mf).to_pandas()

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

        # Make sure the read_gbq_function path works for this function.
        is_sum_positive_series_ref = session.read_gbq_function(
            function_name=is_sum_positive_series_mf.bigframes_bigquery_function,
            is_row_processor=True,
        )

        # This is for callable `other` arg in dataframe.where method.
        def func_for_other(x):
            return -x

        bf_result_gbq = bf_int64_df_filtered.where(
            is_sum_positive_series_ref, func_for_other
        ).to_pandas()
        pd_result_gbq = pd_int64_df_filtered.where(
            is_sum_positive_series, func_for_other
        )

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(
            bf_result_gbq, pd_result_gbq, check_dtype=False
        )

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(
            is_sum_positive_series_mf, session.bqclient, ignore_failures=False
        )

tests/system/large/functions/test_remote_function.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2847,3 +2847,86 @@ def foo(x: int) -> int:
28472847
finally:
28482848
# clean up the gcp assets created for the remote function
28492849
cleanup_function_assets(foo, session.bqclient, session.cloudfunctionsclient)
2850+
2851+
2852+
@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where(session, dataset_id, scalars_dfs):
    """Check ``DataFrame.where`` with a remote function as the condition.

    A two-argument bool remote function is used as ``cond`` with a constant
    ``other`` of 0, and the result is compared against an equivalent pandas
    computation.
    """

    # The return type has to be bool type for callable where condition.
    def is_sum_positive(a, b):
        return a + b > 0

    # Deploy the remote function before entering ``try``: if deployment
    # raises, the ``finally`` clause must not hide the real error behind an
    # unbound-name failure.
    is_sum_positive_rf = session.remote_function(
        input_types=[int, int],
        output_type=bool,
        dataset=dataset_id,
        reuse=False,
        cloud_function_service_account="default",
    )(is_sum_positive)

    try:
        scalars_df, scalars_pandas_df = scalars_dfs
        int64_cols = ["int64_col", "int64_too"]

        # Drop nulls so the two-int function receives valid inputs on every row.
        bf_int64_df_filtered = scalars_df[int64_cols].dropna()
        pd_int64_df_filtered = scalars_pandas_df[int64_cols].dropna()

        # Use callable condition in dataframe.where method.
        bf_result = bf_int64_df_filtered.where(is_sum_positive_rf, 0).to_pandas()
        # pandas does not support this case (a scalar function applied per
        # row), so build the equivalent boolean mask explicitly.
        pd_result = pd_int64_df_filtered.where(pd_int64_df_filtered.sum(axis=1) > 0, 0)

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function. The cloud
        # functions client is passed as in the other remote function cleanups
        # in this file, presumably so the backing cloud function is removed
        # along with the BigQuery routine.
        cleanup_function_assets(
            is_sum_positive_rf,
            session.bqclient,
            session.cloudfunctionsclient,
            ignore_failures=False,
        )
2889+
2890+
2891+
@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_df_where_series(session, dataset_id, scalars_dfs):
    """Check ``DataFrame.where`` with a row-processing remote function.

    The condition takes a whole row (a Series) and returns a bool; ``other``
    is a plain Python callable. Both the plain Python condition and the
    deployed remote function are compared against pandas.
    """

    # The return type has to be bool type for callable where condition.
    def is_sum_positive_series(s):
        return s["int64_col"] + s["int64_too"] > 0

    # Deploy the remote function before entering ``try``: if deployment
    # raises, the ``finally`` clause must not hide the real error behind an
    # unbound-name failure.
    is_sum_positive_series_rf = session.remote_function(
        input_types=bigframes.series.Series,
        output_type=bool,
        dataset=dataset_id,
        reuse=False,
        cloud_function_service_account="default",
    )(is_sum_positive_series)

    try:
        scalars_df, scalars_pandas_df = scalars_dfs
        int64_cols = ["int64_col", "int64_too"]

        bf_int64_df_filtered = scalars_df[int64_cols].dropna()
        pd_int64_df_filtered = scalars_pandas_df[int64_cols].dropna()

        # This is for callable `other` arg in dataframe.where method.
        def func_for_other(x):
            return -x

        pd_result = pd_int64_df_filtered.where(is_sum_positive_series, func_for_other)

        # Plain python callable as the condition in dataframe.where method.
        bf_result_plain = bf_int64_df_filtered.where(
            is_sum_positive_series, func_for_other
        ).to_pandas()

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(
            bf_result_plain, pd_result, check_dtype=False
        )

        # Remote function as the condition. Without this call the function
        # deployed above would never be exercised by the test.
        bf_result = bf_int64_df_filtered.where(
            is_sum_positive_series_rf, func_for_other
        ).to_pandas()

        # Ignore any dtype difference.
        pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function. The cloud
        # functions client is passed as in the other remote function cleanups
        # in this file, presumably so the backing cloud function is removed
        # along with the BigQuery routine.
        cleanup_function_assets(
            is_sum_positive_series_rf,
            session.bqclient,
            session.cloudfunctionsclient,
            ignore_failures=False,
        )

0 commit comments

Comments
 (0)