diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 85b8245272..6b5029f200 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -76,6 +76,7 @@ import bigframes.exceptions as bfe import bigframes.formatting_helpers as formatter import bigframes.functions +from bigframes.functions import function_typing import bigframes.operations as ops import bigframes.operations.aggregations as agg_ops import bigframes.operations.ai @@ -4801,37 +4802,73 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs): ) # Apply the function - result_series = rows_as_json_series._apply_unary_op( - ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True) - ) + if args: + result_series = rows_as_json_series._apply_nary_op( + ops.NaryRemoteFunctionOp(function_def=func.udf_def), + list(args), + ) + else: + result_series = rows_as_json_series._apply_unary_op( + ops.RemoteFunctionOp( + function_def=func.udf_def, apply_on_null=True + ) + ) else: # This is a special case where we are providing not-pandas-like # extension. If the bigquery function can take one or more - # params then we assume that here the user intention is to use - # the column values of the dataframe as arguments to the - # function. For this to work the following condition must be - # true: - # 1. The number or input params in the function must be same - # as the number of columns in the dataframe + # params (excluding the args) then we assume that here the user + # intention is to use the column values of the dataframe as + # arguments to the function. For this to work the following + # conditions must be true: + # 1. The number of input params (excluding the args) in the + # function must be the same as the number of columns in the + # dataframe. # 2. The dtypes of the columns in the dataframe must be - # compatible with the data types of the input params + # compatible with the data types of the input params. # 3. 
The order of the columns in the dataframe must correspond - # to the order of the input params in the function + # to the order of the input params in the function. udf_input_dtypes = func.udf_def.signature.bf_input_types - if len(udf_input_dtypes) != len(self.columns): + if not args and len(udf_input_dtypes) != len(self.columns): raise ValueError( - f"BigFrames BigQuery function takes {len(udf_input_dtypes)}" - f" arguments but DataFrame has {len(self.columns)} columns." + f"Parameter count mismatch: BigFrames BigQuery function" + f" expected {len(udf_input_dtypes)} parameters but" + f" received {len(self.columns)} DataFrame columns." ) - if udf_input_dtypes != tuple(self.dtypes.to_list()): + if args and len(udf_input_dtypes) != len(self.columns) + len(args): raise ValueError( - f"BigFrames BigQuery function takes arguments of types " - f"{udf_input_dtypes} but DataFrame dtypes are {tuple(self.dtypes)}." + f"Parameter count mismatch: BigFrames BigQuery function" + f" expected {len(udf_input_dtypes)} parameters but" + f" received {len(self.columns) + len(args)} values" + f" ({len(self.columns)} DataFrame columns and" + f" {len(args)} args)." ) + end_slice = -len(args) if args else None + if udf_input_dtypes[:end_slice] != tuple(self.dtypes.to_list()): + raise ValueError( + f"Data type mismatch for DataFrame columns:" + f" Expected {udf_input_dtypes[:end_slice]}" + f" Received {tuple(self.dtypes)}." + ) + if args: + bq_types = ( + function_typing.sdk_type_from_python_type(type(arg)) + for arg in args + ) + args_dtype = tuple( + function_typing.sdk_type_to_bf_type(bq_type) + for bq_type in bq_types + ) + if udf_input_dtypes[end_slice:] != args_dtype: + raise ValueError( + f"Data type mismatch for 'args' parameter:" + f" Expected {udf_input_dtypes[end_slice:]}" + f" Received {args_dtype}." 
+ ) series_list = [self[col] for col in self.columns] + op_list = series_list[1:] + list(args) result_series = series_list[0]._apply_nary_op( - ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:] + ops.NaryRemoteFunctionOp(function_def=func.udf_def), op_list ) result_series.name = None diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index 90bfb89c56..a2fb66539b 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -959,11 +959,16 @@ def _convert_row_processor_sig( ) -> Optional[inspect.Signature]: import bigframes.series as bf_series - if len(signature.parameters) == 1: - only_param = next(iter(signature.parameters.values())) - param_type = only_param.annotation + if len(signature.parameters) >= 1: + first_param = next(iter(signature.parameters.values())) + param_type = first_param.annotation if (param_type == bf_series.Series) or (param_type == pandas.Series): msg = bfe.format_message("input_types=Series is in preview.") warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning) - return signature.replace(parameters=[only_param.replace(annotation=str)]) + return signature.replace( + parameters=[ + p.replace(annotation=str) if i == 0 else p + for i, p in enumerate(signature.parameters.values()) + ] + ) return None diff --git a/bigframes/functions/function.py b/bigframes/functions/function.py index a62da57075..99b89131e7 100644 --- a/bigframes/functions/function.py +++ b/bigframes/functions/function.py @@ -178,13 +178,6 @@ def read_gbq_function( ValueError, f"Unknown function '{routine_ref}'." ) - if is_row_processor and len(routine.arguments) > 1: - raise bf_formatting.create_exception_with_feedback_link( - ValueError, - "A multi-input function cannot be a row processor. 
A row processor function " - "takes in a single input representing the row.", - ) - if is_row_processor: return _try_import_row_routine(routine, session) else: diff --git a/bigframes/functions/function_template.py b/bigframes/functions/function_template.py index 5f04fcc8e2..dd31de7243 100644 --- a/bigframes/functions/function_template.py +++ b/bigframes/functions/function_template.py @@ -195,7 +195,9 @@ def udf_http_row_processor(request): calls = request_json["calls"] replies = [] for call in calls: - reply = convert_to_bq_json(output_type, udf(get_pd_series(call[0]))) + reply = convert_to_bq_json( + output_type, udf(get_pd_series(call[0]), *call[1:]) + ) if type(reply) is list: # Since the BQ remote function does not support array yet, # return a json serialized version of the reply. @@ -332,6 +334,28 @@ def generate_managed_function_code( f"""def bigframes_handler(str_arg): return {udf_name}({get_pd_series.__name__}(str_arg))""" ) + + sig = inspect.signature(def_) + params = list(sig.parameters.values()) + additional_params = params[1:] + + # Build the parameter list for the new handler function definition. + # e.g., "str_arg, y: bool, z" + handler_def_parts = ["str_arg"] + handler_def_parts.extend(str(p) for p in additional_params) + handler_def_str = ", ".join(handler_def_parts) + + # Build the argument list for the call to the original UDF. 
+ # e.g., "get_pd_series(str_arg), y, z" + udf_call_parts = [f"{get_pd_series.__name__}(str_arg)"] + udf_call_parts.extend(p.name for p in additional_params) + udf_call_str = ", ".join(udf_call_parts) + + bigframes_handler_code = textwrap.dedent( + f"""def bigframes_handler({handler_def_str}): + return {udf_name}({udf_call_str})""" + ) + else: udf_code = "" bigframes_handler_code = textwrap.dedent( diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 73335afa3c..b0e44b648f 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -468,20 +468,20 @@ def foo(x, y, z): # Fails to apply on dataframe with incompatible number of columns. with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.", ): bf_df[["Id", "Age"]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes. 
with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1) @@ -965,6 +965,117 @@ def float_parser(row): ) +def test_managed_function_df_apply_axis_1_args(session, dataset_id, scalars_dfs): + columns = ["int64_col", "int64_too"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + def the_sum(s1, s2, x): + return s1 + s2 + x + + the_sum_mf = session.udf( + input_types=[int, int, int], + output_type=int, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(the_sum) + + args1 = (1,) + + # Fails to apply on dataframe with incompatible number of columns and args. + with pytest.raises( + ValueError, + match="^Parameter count mismatch:.* expected 3 parameters but received 4 values \\(3 DataFrame columns and 1 args\\)", + ): + scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1) + + # Fails to apply on dataframe with incompatible column datatypes. + with pytest.raises( + ValueError, + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", + ): + scalars_df[columns].assign( + int64_col=lambda df: df["int64_col"].astype("Float64") + ).apply(the_sum_mf, axis=1, args=args1) + + # Fails to apply on dataframe with incompatible args datatypes. + with pytest.raises( + ValueError, + match="^Data type mismatch for 'args' parameter: Expected .* Received .*", + ): + scalars_df[columns].apply(the_sum_mf, axis=1, args=(1.3,)) + + bf_result = ( + scalars_df[columns] + .dropna() + .apply(the_sum_mf, axis=1, args=args1) + .to_pandas() + ) + pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + finally: + # clean up the gcp assets created for the managed function. 
+ cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + +def test_managed_function_df_apply_axis_1_series_args(session, dataset_id, scalars_dfs): + columns = ["int64_col", "float64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + def analyze(s, x, y): + value = f"value is {s['int64_col']} and {s['float64_col']}" + if x: + return f"{value}, x is True!" + if y > 0: + return f"{value}, x is False, y is positive!" + return f"{value}, x is False, y is non-positive!" + + analyze_mf = session.udf( + input_types=[bigframes.series.Series, bool, float], + output_type=str, + dataset=dataset_id, + name=prefixer.create_prefix(), + )(analyze) + + args1 = (True, 10.0) + bf_result = ( + scalars_df[columns] + .dropna() + .apply(analyze_mf, axis=1, args=args1) + .to_pandas() + ) + pd_result = ( + scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args1) + ) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + args2 = (False, -10.0) + analyze_mf_ref = session.read_gbq_function( + analyze_mf.bigframes_bigquery_function, is_row_processor=True + ) + bf_result = ( + scalars_df[columns] + .dropna() + .apply(analyze_mf_ref, axis=1, args=args2) + .to_pandas() + ) + pd_result = ( + scalars_pandas_df[columns].dropna().apply(analyze, axis=1, args=args2) + ) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + finally: + # clean up the gcp assets created for the managed function. 
+ cleanup_function_assets(analyze_mf, session.bqclient, ignore_failures=False) + + def test_managed_function_df_where_mask(session, dataset_id, scalars_dfs): try: diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index 3c453a52a4..e6372d768b 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -1937,6 +1937,114 @@ def float_parser(row): ) +@pytest.mark.flaky(retries=2, delay=120) +def test_df_apply_axis_1_args(session, scalars_dfs): + columns = ["int64_col", "int64_too"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + def the_sum(s1, s2, x): + return s1 + s2 + x + + the_sum_mf = session.remote_function( + input_types=[int, int, int], + output_type=int, + reuse=False, + cloud_function_service_account="default", + )(the_sum) + + args1 = (1,) + + # Fails to apply on dataframe with incompatible number of columns and args. + with pytest.raises( + ValueError, + match="^Parameter count mismatch:.* expected 3 parameters but received 4 values \\(2 DataFrame columns and 2 args\\)", + ): + scalars_df[columns].apply( + the_sum_mf, + axis=1, + args=( + 1, + 1, + ), + ) + + # Fails to apply on dataframe with incompatible column datatypes. + with pytest.raises( + ValueError, + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", + ): + scalars_df[columns].assign( + int64_col=lambda df: df["int64_col"].astype("Float64") + ).apply(the_sum_mf, axis=1, args=args1) + + # Fails to apply on dataframe with incompatible args datatypes. 
+ with pytest.raises( + ValueError, + match="^Data type mismatch for 'args' parameter: Expected .* Received .*", + ): + scalars_df[columns].apply(the_sum_mf, axis=1, args=("hello world",)) + + bf_result = ( + scalars_df[columns] + .dropna() + .apply(the_sum_mf, axis=1, args=args1) + .to_pandas() + ) + pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1) + + pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False) + + finally: + # clean up the gcp assets created for the remote function. + cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_df_apply_axis_1_series_args(session, scalars_dfs): + columns = ["int64_col", "float64_col"] + scalars_df, scalars_pandas_df = scalars_dfs + + try: + + @session.remote_function( + input_types=[bigframes.series.Series, float, str, bool], + output_type=list[str], + reuse=False, + cloud_function_service_account="default", + ) + def foo_list(x, y0: float, y1, y2) -> list[str]: + return ( + [str(x["int64_col"]), str(y0), str(y1), str(y2)] + if y2 + else [str(x["float64_col"])] + ) + + args1 = (12.34, "hello world", True) + bf_result = scalars_df[columns].apply(foo_list, axis=1, args=args1).to_pandas() + pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=args1) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + args2 = (43.21, "xxx3yyy", False) + foo_list_ref = session.read_gbq_function( + foo_list.bigframes_bigquery_function, is_row_processor=True + ) + bf_result = ( + scalars_df[columns].apply(foo_list_ref, axis=1, args=args2).to_pandas() + ) + pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=args2) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. 
+ cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False) + + @pytest.mark.parametrize( ("memory_mib_args", "expected_memory"), [ @@ -2200,19 +2308,19 @@ def foo(x, y, z): # Fails to apply on dataframe with incompatible number of columns with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.", ): bf_df[["Id", "Age"]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1) @@ -2284,19 +2392,19 @@ def foo(x, y, z): # Fails to apply on dataframe with incompatible number of columns with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 2 DataFrame columns.", ): bf_df[["Id", "Age"]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 3 arguments but DataFrame has 4 columns\\.$", + match="^Parameter count mismatch:.* expected 3 parameters but received 4 DataFrame columns.", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for 
DataFrame columns: Expected .* Received .*", ): bf_df.assign(Age=bf_df["Age"].astype("Int64")).apply(foo, axis=1) @@ -2358,19 +2466,19 @@ def foo(x): # Fails to apply on dataframe with incompatible number of columns with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 0 columns\\.$", + match="^Parameter count mismatch:.* expected 1 parameters but received 0 DataFrame.*", ): bf_df[[]].apply(foo, axis=1) with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes 1 arguments but DataFrame has 2 columns\\.$", + match="^Parameter count mismatch:.* expected 1 parameters but received 2 DataFrame.*", ): bf_df.assign(Country="lalaland").apply(foo, axis=1) # Fails to apply on dataframe with incompatible column datatypes with pytest.raises( ValueError, - match="^BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*", + match="^Data type mismatch for DataFrame columns: Expected .* Received .*", ): bf_df.assign(Id=bf_df["Id"].astype("Float64")).apply(foo, axis=1) diff --git a/tests/system/small/functions/test_remote_function.py b/tests/system/small/functions/test_remote_function.py index 86076e764f..28fab19144 100644 --- a/tests/system/small/functions/test_remote_function.py +++ b/tests/system/small/functions/test_remote_function.py @@ -1154,20 +1154,6 @@ def test_df_apply_scalar_func(session, scalars_dfs): ) -def test_read_gbq_function_multiple_inputs_not_a_row_processor(session): - with pytest.raises(ValueError) as context: - # The remote function has two args, which cannot be row processed. Throw - # a ValueError for it. - session.read_gbq_function( - function_name="bqutil.fn.cw_regexp_instr_2", - is_row_processor=True, - ) - assert str(context.value) == ( - "A multi-input function cannot be a row processor. A row processor function " - f"takes in a single input representing the row. 
{constants.FEEDBACK_LINK}" - ) - - @pytest.mark.flaky(retries=2, delay=120) def test_df_apply_axis_1(session, scalars_dfs, dataset_id_permanent): columns = [