Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,9 +1898,22 @@ def _groupby_values(
)

def apply(
self, func, by_row: typing.Union[typing.Literal["compat"], bool] = "compat"
self,
func,
by_row: typing.Union[typing.Literal["compat"], bool] = "compat",
*,
args: typing.Tuple = (),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In pandas, `args` can be passed as a positional argument while `by_row` is keyword-only (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.apply.html#pandas-series-apply). I would argue that we can adhere to that right now, as I don't anticipate many people passing `by_row` as a positional argument; if not, we should add an item for BigFrames 3.0 to make the breaking change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I also thought about it. It's difficult to fully match the pandas API here for now. I have added some comments about potential breaking changes in the future. However, given that this is an edge case, I recommend we defer any such changes until there is a clear need based on user feedback.

) -> Series:
# TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs
# Note: This signature differs from pandas.Series.apply. Specifically,
# `args` is keyword-only and `by_row` is a custom parameter here. Full
# alignment would involve breaking changes. However, given that by_row
# is not frequently used, we defer any such changes until there is a
# clear need based on user feedback.
#
# See pandas docs for reference:
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.apply.html

# TODO(shobs, b/274645634): Support convert_dtype, **kwargs
# is actually a ternary op

if by_row not in ["compat", False]:
Expand Down Expand Up @@ -1944,10 +1957,19 @@ def apply(
raise

# We are working with bigquery function at this point
result_series = self._apply_unary_op(
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
)
if args:
result_series = self._apply_nary_op(
ops.NaryRemoteFunctionOp(function_def=func.udf_def), args
)
# TODO(jialuo): Investigate why `_apply_nary_op` drops the series
# `name`. Manually reassigning it here as a temporary fix.
result_series.name = self.name
else:
result_series = self._apply_unary_op(
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
)
result_series = func._post_process_series(result_series)

return result_series

def combine(
Expand Down
28 changes: 28 additions & 0 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,3 +1111,31 @@ def _is_positive(s):
finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)


def test_managed_function_series_apply_args(session, dataset_id, scalars_dfs):
    """Series.apply with a managed function (udf) should forward extra
    positional arguments through the keyword-only ``args`` tuple, matching
    the pandas ``Series.apply(func, args=...)`` behavior.
    """
    # Pre-bind so the cleanup in ``finally`` cannot raise NameError if the
    # udf deployment itself fails before ``foo_list`` is defined.
    foo_list = None
    try:

        with pytest.warns(bfe.PreviewWarning, match="udf is in preview."):

            @session.udf(dataset=dataset_id, name=prefixer.create_prefix())
            def foo_list(x: int, y0: float, y1: bytes, y2: bool) -> list[str]:
                return [str(x), str(y0), str(y1), str(y2)]

        scalars_df, scalars_pandas_df = scalars_dfs

        # Extra positional arguments bound to y0, y1, y2; the series value
        # itself is bound to x. Use the same tuple on both sides.
        extra_args = (12.34, b"hello world", False)

        bf_result = (
            scalars_df["int64_too"]
            .apply(foo_list, args=extra_args)
            .to_pandas()
        )
        pd_result = scalars_pandas_df["int64_too"].apply(foo_list, args=extra_args)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the managed function, if it
        # was actually created before a failure.
        if foo_list is not None:
            cleanup_function_assets(
                foo_list, session.bqclient, ignore_failures=False
            )
39 changes: 39 additions & 0 deletions tests/system/large/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2969,3 +2969,42 @@ def _ten_times(x):
finally:
# Clean up the gcp assets created for the remote function.
cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False)


@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_series_apply_args(session, dataset_id, scalars_dfs):
    """Series.apply with a remote function should forward extra positional
    arguments through the keyword-only ``args`` tuple, both for the freshly
    deployed function and for a handle re-read via ``read_gbq_function``.
    """
    # Pre-bind so the cleanup in ``finally`` cannot raise NameError if the
    # remote-function deployment itself fails before ``foo`` is defined.
    foo = None
    try:

        @session.remote_function(
            dataset=dataset_id,
            reuse=False,
            cloud_function_service_account="default",
        )
        def foo(x: int, y: bool, z: float) -> str:
            if y:
                return f"{x}: y is True."
            if z > 0.0:
                return f"{x}: y is False and z is positive."
            return f"{x}: y is False and z is non-positive."

        scalars_df, scalars_pandas_df = scalars_dfs

        # Exercise the "y is True" branch with the directly deployed function.
        args1 = (True, 10.0)
        bf_result = scalars_df["int64_too"].apply(foo, args=args1).to_pandas()
        pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args1)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

        # Exercise the "y is False" branches with a handle obtained via
        # read_gbq_function; pandas side still uses the local ``foo``.
        args2 = (False, -10.0)
        foo_ref = session.read_gbq_function(foo.bigframes_bigquery_function)

        bf_result = scalars_df["int64_too"].apply(foo_ref, args=args2).to_pandas()
        pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args2)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function, if it was
        # actually created before a failure. ``foo_ref`` needs no separate
        # cleanup: it refers to the same BigQuery function.
        if foo is not None:
            cleanup_function_assets(foo, session.bqclient, ignore_failures=False)