Commit 164c481

feat: Support args in dataframe apply method (#2026)
* feat: Allow passing args to managed functions in DataFrame apply method
* remove a test
* support remote function
* resolve the comments
* improve the message
* fix the tests
1 parent 7072627 commit 164c481
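
A rough usage sketch of the feature (the function, dataset name, and column names below are placeholders, not taken from this commit): scalars passed through args= are forwarded to the trailing parameters of a BigFrames UDF, after the DataFrame columns.

import bigframes.pandas as bpd

def add_with_offset(x: int, y: int, offset: int) -> int:
    return x + y + offset

# session.udf(...) mirrors the pattern used in the tests below; "my_dataset"
# is a placeholder for a dataset the session can write to.
session = bpd.get_global_session()
add_with_offset_udf = session.udf(
    input_types=[int, int, int],
    output_type=int,
    dataset="my_dataset",
)(add_with_offset)

df = bpd.DataFrame({"a": [1, 2], "b": [10, 20]})
# Columns "a" and "b" bind to the first two parameters; the trailing
# parameter "offset" is supplied via args.
result = df.apply(add_with_offset_udf, axis=1, args=(100,))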

7 files changed: +320 additions, -56 deletions


bigframes/dataframe.py

Lines changed: 55 additions & 18 deletions
@@ -77,6 +77,7 @@
 import bigframes.exceptions as bfe
 import bigframes.formatting_helpers as formatter
 import bigframes.functions
+from bigframes.functions import function_typing
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
 import bigframes.operations.ai
@@ -4835,37 +4836,73 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
             )
 
             # Apply the function
-            result_series = rows_as_json_series._apply_unary_op(
-                ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
-            )
+            if args:
+                result_series = rows_as_json_series._apply_nary_op(
+                    ops.NaryRemoteFunctionOp(function_def=func.udf_def),
+                    list(args),
+                )
+            else:
+                result_series = rows_as_json_series._apply_unary_op(
+                    ops.RemoteFunctionOp(
+                        function_def=func.udf_def, apply_on_null=True
+                    )
+                )
         else:
             # This is a special case where we are providing not-pandas-like
             # extension. If the bigquery function can take one or more
-            # params then we assume that here the user intention is to use
-            # the column values of the dataframe as arguments to the
-            # function. For this to work the following condition must be
-            # true:
-            # 1. The number or input params in the function must be same
-            #    as the number of columns in the dataframe
+            # params (excluding the args) then we assume that here the user
+            # intention is to use the column values of the dataframe as
+            # arguments to the function. For this to work the following
+            # condition must be true:
+            # 1. The number or input params (excluding the args) in the
+            #    function must be same as the number of columns in the
+            #    dataframe.
             # 2. The dtypes of the columns in the dataframe must be
-            #    compatible with the data types of the input params
+            #    compatible with the data types of the input params.
             # 3. The order of the columns in the dataframe must correspond
-            #    to the order of the input params in the function
+            #    to the order of the input params in the function.
             udf_input_dtypes = func.udf_def.signature.bf_input_types
-            if len(udf_input_dtypes) != len(self.columns):
+            if not args and len(udf_input_dtypes) != len(self.columns):
                 raise ValueError(
-                    f"BigFrames BigQuery function takes {len(udf_input_dtypes)}"
-                    f" arguments but DataFrame has {len(self.columns)} columns."
+                    f"Parameter count mismatch: BigFrames BigQuery function"
+                    f" expected {len(udf_input_dtypes)} parameters but"
+                    f" received {len(self.columns)} DataFrame columns."
                 )
-            if udf_input_dtypes != tuple(self.dtypes.to_list()):
+            if args and len(udf_input_dtypes) != len(self.columns) + len(args):
                 raise ValueError(
-                    f"BigFrames BigQuery function takes arguments of types "
-                    f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}."
+                    f"Parameter count mismatch: BigFrames BigQuery function"
+                    f" expected {len(udf_input_dtypes)} parameters but"
+                    f" received {len(self.columns) + len(args)} values"
+                    f" ({len(self.columns)} DataFrame columns and"
+                    f" {len(args)} args)."
                 )
+            end_slice = -len(args) if args else None
+            if udf_input_dtypes[:end_slice] != tuple(self.dtypes.to_list()):
+                raise ValueError(
+                    f"Data type mismatch for DataFrame columns:"
+                    f" Expected {udf_input_dtypes[:end_slice]}"
+                    f" Received {tuple(self.dtypes)}."
+                )
+            if args:
+                bq_types = (
+                    function_typing.sdk_type_from_python_type(type(arg))
+                    for arg in args
+                )
+                args_dtype = tuple(
+                    function_typing.sdk_type_to_bf_type(bq_type)
+                    for bq_type in bq_types
+                )
+                if udf_input_dtypes[end_slice:] != args_dtype:
+                    raise ValueError(
+                        f"Data type mismatch for 'args' parameter:"
+                        f" Expected {udf_input_dtypes[end_slice:]}"
+                        f" Received {args_dtype}."
+                    )
 
             series_list = [self[col] for col in self.columns]
+            op_list = series_list[1:] + list(args)
             result_series = series_list[0]._apply_nary_op(
-                ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:]
+                ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list
             )
             result_series.name = None
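
To make the new validation concrete, a minimal standalone sketch (plain Python with made-up dtype strings, not the bigframes implementation) of how the expected UDF input types are split between DataFrame columns and trailing args:

# Hypothetical expected parameter dtypes of the UDF, in order.
udf_input_dtypes = ("Int64", "Int64", "Float64")
column_dtypes = ("Int64", "Int64")  # dtypes of the DataFrame columns
args = (1.5,)                       # scalars passed via apply(..., args=...)

# The last len(args) parameters are reserved for args; everything before
# end_slice must line up with the DataFrame columns.
end_slice = -len(args) if args else None
assert udf_input_dtypes[:end_slice] == column_dtypes
# In the real code each arg's Python type is mapped to a BigFrames dtype via
# function_typing before being compared with udf_input_dtypes[end_slice:].
assert len(udf_input_dtypes[end_slice:]) == len(args)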

bigframes/functions/_function_session.py

Lines changed: 9 additions & 4 deletions
@@ -959,11 +959,16 @@ def _convert_row_processor_sig(
 ) -> Optional[inspect.Signature]:
     import bigframes.series as bf_series
 
-    if len(signature.parameters) == 1:
-        only_param = next(iter(signature.parameters.values()))
-        param_type = only_param.annotation
+    if len(signature.parameters) >= 1:
+        first_param = next(iter(signature.parameters.values()))
+        param_type = first_param.annotation
         if (param_type == bf_series.Series) or (param_type == pandas.Series):
             msg = bfe.format_message("input_types=Series is in preview.")
             warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
-            return signature.replace(parameters=[only_param.replace(annotation=str)])
+            return signature.replace(
+                parameters=[
+                    p.replace(annotation=str) if i == 0 else p
+                    for i, p in enumerate(signature.parameters.values())
+                ]
+            )
     return None
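
A standalone illustration of the rewrite above, using only the standard inspect module (the sample function is hypothetical): only the first parameter's annotation is swapped to str, since the row arrives serialized, while later parameters keep their annotations.

import inspect

import pandas

def analyze(row: pandas.Series, flag: bool, threshold: float) -> str:
    return "ok"

sig = inspect.signature(analyze)
new_sig = sig.replace(
    parameters=[
        # Rewrite only the first (row) parameter; leave the scalar args alone.
        p.replace(annotation=str) if i == 0 else p
        for i, p in enumerate(sig.parameters.values())
    ]
)
print(new_sig)  # (row: str, flag: bool, threshold: float) -> str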

bigframes/functions/function.py

Lines changed: 0 additions & 7 deletions
@@ -178,13 +178,6 @@ def read_gbq_function(
             ValueError, f"Unknown function '{routine_ref}'."
         )
 
-    if is_row_processor and len(routine.arguments) > 1:
-        raise bf_formatting.create_exception_with_feedback_link(
-            ValueError,
-            "A multi-input function cannot be a row processor. A row processor function "
-            "takes in a single input representing the row.",
-        )
-
     if is_row_processor:
         return _try_import_row_routine(routine, session)
     else:
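
With that guard removed, a BigQuery routine that takes extra scalar parameters can be re-imported as a row processor. A hedged sketch (session, bf_df, and the routine path are placeholders) of the pattern the tests below exercise:

# Assumes the routine's first parameter carries the serialized row and the
# remaining parameters are scalars supplied through args.
analyze_ref = session.read_gbq_function(
    "my-project.my_dataset.analyze_row", is_row_processor=True
)
result = bf_df.apply(analyze_ref, axis=1, args=(True, 10.0))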

bigframes/functions/function_template.py

Lines changed: 25 additions & 1 deletion
@@ -195,7 +195,9 @@ def udf_http_row_processor(request):
     calls = request_json["calls"]
     replies = []
     for call in calls:
-        reply = convert_to_bq_json(output_type, udf(get_pd_series(call[0])))
+        reply = convert_to_bq_json(
+            output_type, udf(get_pd_series(call[0]), *call[1:])
+        )
         if type(reply) is list:
             # Since the BQ remote function does not support array yet,
             # return a json serialized version of the reply.
@@ -332,6 +334,28 @@ def generate_managed_function_code(
             f"""def bigframes_handler(str_arg):
     return {udf_name}({get_pd_series.__name__}(str_arg))"""
         )
+
+        sig = inspect.signature(def_)
+        params = list(sig.parameters.values())
+        additional_params = params[1:]
+
+        # Build the parameter list for the new handler function definition.
+        # e.g., "str_arg, y: bool, z"
+        handler_def_parts = ["str_arg"]
+        handler_def_parts.extend(str(p) for p in additional_params)
+        handler_def_str = ", ".join(handler_def_parts)
+
+        # Build the argument list for the call to the original UDF.
+        # e.g., "get_pd_series(str_arg), y, z"
+        udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"]
+        udf_call_parts.extend(p.name for p in additional_params)
+        udf_call_str = ", ".join(udf_call_parts)
+
+        bigframes_handler_code = textwrap.dedent(
+            f"""def bigframes_handler({handler_def_str}):
+    return {udf_name}({udf_call_str})"""
+        )
+
     else:
         udf_code = ""
         bigframes_handler_code = textwrap.dedent(
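
For reference, a self-contained sketch of what the string building above produces for a sample UDF with extra scalar parameters (the function and names are illustrative, not part of the commit):

import inspect
import textwrap

def analyze(s, x: bool, y: float):  # sample row-processor UDF
    ...

udf_name = "analyze"
sig = inspect.signature(analyze)
additional_params = list(sig.parameters.values())[1:]

# "str_arg, x: bool, y: float"
handler_def_str = ", ".join(["str_arg"] + [str(p) for p in additional_params])
# "get_pd_series(str_arg), x, y"
udf_call_str = ", ".join(
    ["get_pd_series(str_arg)"] + [p.name for p in additional_params]
)

print(
    textwrap.dedent(
        f"""def bigframes_handler({handler_def_str}):
    return {udf_name}({udf_call_str})"""
    )
)
# Prints:
# def bigframes_handler(str_arg, x: bool, y: float):
#     return analyze(get_pd_series(str_arg), x, y)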

tests/system/large/functions/test_managed_function.py

Lines changed: 114 additions & 3 deletions
@@ -468,20 +468,20 @@ def foo(x, y, z):
     # Fails to apply on dataframe with incompatible number of columns.
     with pytest.raises(
         ValueError,
-        match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$",
+        match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.",
     ):
         bf_df[["Id", "Age"]].apply(foo, axis=1)
 
     with pytest.raises(
         ValueError,
-        match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$",
+        match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.",
     ):
         bf_df.assign(Country="lalaland").apply(foo, axis=1)
 
     # Fails to apply on dataframe with incompatible column datatypes.
     with pytest.raises(
         ValueError,
-        match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
+        match="^Data type mismatch for DataFrame columns: Expected .* Received .*",
     ):
         bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1)
 
@@ -965,6 +965,117 @@ def float_parser(row):
     )
 
 
+def test_managed_function_df_apply_axis_1_args(session, dataset_id, scalars_dfs):
+    columns = ["int64_col", "int64_too"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    try:
+
+        def the_sum(s1, s2, x):
+            return s1 + s2 + x
+
+        the_sum_mf = session.udf(
+            input_types=[int, int, int],
+            output_type=int,
+            dataset=dataset_id,
+            name=prefixer.create_prefix(),
+        )(the_sum)
+
+        args1 = (1,)
+
+        # Fails to apply on dataframe with incompatible number of columns and args.
+        with pytest.raises(
+            ValueError,
+            match="^Parameter count mismatch:.* expected 3 parameters but received 4 values \\(3 DataFrame columns and 1 args\\)",
+        ):
+            scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1)
+
+        # Fails to apply on dataframe with incompatible column datatypes.
+        with pytest.raises(
+            ValueError,
+            match="^Data type mismatch for DataFrame columns: Expected .* Received .*",
+        ):
+            scalars_df[columns].assign(
+                int64_col=lambda df: df["int64_col"].astype("Float64")
+            ).apply(the_sum_mf, axis=1, args=args1)
+
+        # Fails to apply on dataframe with incompatible args datatypes.
+        with pytest.raises(
+            ValueError,
+            match="^Data type mismatch for 'args' parameter: Expected .* Received .*",
+        ):
+            scalars_df[columns].apply(the_sum_mf, axis=1, args=(1.3,))
+
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(the_sum_mf, axis=1, args=args1)
+            .to_pandas()
+        )
+        pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1)
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+    finally:
+        # clean up the gcp assets created for the managed function.
+        cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
+
+
+def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scalars_dfs):
+    columns = ["int64_col", "float64_col"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    try:
+
+        def analyze(s, x, y):
+            value = f"value is {s['int64_col']} and {s['float64_col']}"
+            if x:
+                return f"{value}, x is True!"
+            if y > 0:
+                return f"{value}, x is False, y is positive!"
+            return f"{value}, x is False, y is non-positive!"
+
+        analyze_mf = session.udf(
+            input_types=[bigframes.series.Series, bool, float],
+            output_type=str,
+            dataset=dataset_id,
+            name=prefixer.create_prefix(),
+        )(analyze)
+
+        args1 = (True, 10.0)
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(analyze_mf, axis=1, args=args1)
+            .to_pandas()
+        )
+        pd_result = (
+            scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args1)
+        )
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+        args2 = (False, -10.0)
+        analyze_mf_ref = session.read_gbq_function(
+            analyze_mf.bigframes_bigquery_function, is_row_processor=True
+        )
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(analyze_mf_ref, axis=1, args=args2)
+            .to_pandas()
+        )
+        pd_result = (
+            scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args2)
+        )
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+    finally:
+        # clean up the gcp assets created for the managed function.
+        cleanup_function_assets(analyze_mf, session.bqclient, ignore_failures=False)
+
+
 def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs):
     try:
