diff --git a/bigframes/series.py b/bigframes/series.py index 6f48935ec9..2005581ea4 100644 --- a/bigframes/series.py +++ b/bigframes/series.py @@ -1898,9 +1898,22 @@ def _groupby_values( ) def apply( - self, func, by_row: typing.Union[typing.Literal["compat"], bool] = "compat" + self, + func, + by_row: typing.Union[typing.Literal["compat"], bool] = "compat", + *, + args: typing.Tuple = (), ) -> Series: - # TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs + # Note: This signature differs from pandas.Series.apply. Specifically, + # `args` is keyword-only and `by_row` is a custom parameter here. Full + # alignment would involve breaking changes. However, given that by_row + # is not frequently used, we defer any such changes until there is a + # clear need based on user feedback. + # + # See pandas docs for reference: + # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.apply.html + + # TODO(shobs, b/274645634): Support convert_dtype, **kwargs # is actually a ternary op if by_row not in ["compat", False]: @@ -1944,10 +1957,19 @@ def apply( raise # We are working with bigquery function at this point - result_series = self._apply_unary_op( - ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True) - ) + if args: + result_series = self._apply_nary_op( + ops.NaryRemoteFunctionOp(function_def=func.udf_def), args + ) + # TODO(jialuo): Investigate why `_apply_nary_op` drops the series + # `name`. Manually reassigning it here as a temporary fix. + result_series.name = self.name + else: + result_series = self._apply_unary_op( + ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True) + ) result_series = func._post_process_series(result_series) + return result_series def combine( diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 262f5f0fe2..74b24a60b3 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -1111,3 +1111,31 @@ def _is_positive(s): finally: # Clean up the gcp assets created for the managed function. cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False) + + +def test_managed_function_series_apply_args(session, dataset_id, scalars_dfs): + try: + + with pytest.warns(bfe.PreviewWarning, match="udf is in preview."): + + @session.udf(dataset=dataset_id, name=prefixer.create_prefix()) + def foo_list(x: int, y0: float, y1: bytes, y2: bool) -> list[str]: + return [str(x), str(y0), str(y1), str(y2)] + + scalars_df, scalars_pandas_df = scalars_dfs + + bf_result = ( + scalars_df["int64_too"] + .apply(foo_list, args=(12.34, b"hello world", False)) + .to_pandas() + ) + pd_result = scalars_pandas_df["int64_too"].apply( + foo_list, args=(12.34, b"hello world", False) + ) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the managed function. + cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False) diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index 9e2c1e2c81..4518bddb80 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -2969,3 +2969,42 @@ def _ten_times(x): finally: # Clean up the gcp assets created for the remote function. cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False) + + +@pytest.mark.flaky(retries=2, delay=120) +def test_remote_function_series_apply_args(session, dataset_id, scalars_dfs): + try: + + @session.remote_function( + dataset=dataset_id, + reuse=False, + cloud_function_service_account="default", + ) + def foo(x: int, y: bool, z: float) -> str: + if y: + return f"{x}: y is True." + if z > 0.0: + return f"{x}: y is False and z is positive." + return f"{x}: y is False and z is non-positive." + + scalars_df, scalars_pandas_df = scalars_dfs + + args1 = (True, 10.0) + bf_result = scalars_df["int64_too"].apply(foo, args=args1).to_pandas() + pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args1) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + args2 = (False, -10.0) + foo_ref = session.read_gbq_function(foo.bigframes_bigquery_function) + + bf_result = scalars_df["int64_too"].apply(foo_ref, args=args2).to_pandas() + pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args2) + + # Ignore any dtype difference. + pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False) + + finally: + # Clean up the gcp assets created for the remote function. + cleanup_function_assets(foo, session.bqclient, ignore_failures=False)