Skip to content

Commit d9d725c

Browse files
authored
feat: Support args in series apply method (#2013)
* feat: Support args in series apply method * resolve the comments
1 parent e300ed1 commit d9d725c

File tree

3 files changed

+94
-5
lines changed

3 files changed

+94
-5
lines changed

bigframes/series.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,9 +1904,22 @@ def _groupby_values(
19041904
)
19051905

19061906
def apply(
1907-
self, func, by_row: typing.Union[typing.Literal["compat"], bool] = "compat"
1907+
self,
1908+
func,
1909+
by_row: typing.Union[typing.Literal["compat"], bool] = "compat",
1910+
*,
1911+
args: typing.Tuple = (),
19081912
) -> Series:
1909-
# TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs
1913+
# Note: This signature differs from pandas.Series.apply. Specifically,
1914+
# `args` is keyword-only and `by_row` is a custom parameter here. Full
1915+
# alignment would involve breaking changes. However, given that by_row
1916+
# is not frequently used, we defer any such changes until there is a
1917+
# clear need based on user feedback.
1918+
#
1919+
# See pandas docs for reference:
1920+
# https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.Series.apply.html
1921+
1922+
# TODO(shobs, b/274645634): Support convert_dtype, **kwargs
19101923
# is actually a ternary op
19111924

19121925
if by_row not in ["compat", False]:
@@ -1950,10 +1963,19 @@ def apply(
19501963
raise
19511964

19521965
# We are working with bigquery function at this point
1953-
result_series = self._apply_unary_op(
1954-
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
1955-
)
1966+
if args:
1967+
result_series = self._apply_nary_op(
1968+
ops.NaryRemoteFunctionOp(function_def=func.udf_def), args
1969+
)
1970+
# TODO(jialuo): Investigate why `_apply_nary_op` drops the series
1971+
# `name`. Manually reassigning it here as a temporary fix.
1972+
result_series.name = self.name
1973+
else:
1974+
result_series = self._apply_unary_op(
1975+
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
1976+
)
19561977
result_series = func._post_process_series(result_series)
1978+
19571979
return result_series
19581980

19591981
def combine(

tests/system/large/functions/test_managed_function.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,3 +1121,31 @@ def _is_positive(s):
11211121
finally:
11221122
# Clean up the gcp assets created for the managed function.
11231123
cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)
1124+
1125+
1126+
def test_managed_function_series_apply_args(session, dataset_id, scalars_dfs):
1127+
try:
1128+
1129+
with pytest.warns(bfe.PreviewWarning, match="udf is in preview."):
1130+
1131+
@session.udf(dataset=dataset_id, name=prefixer.create_prefix())
1132+
def foo_list(x: int, y0: float, y1: bytes, y2: bool) -> list[str]:
1133+
return [str(x), str(y0), str(y1), str(y2)]
1134+
1135+
scalars_df, scalars_pandas_df = scalars_dfs
1136+
1137+
bf_result = (
1138+
scalars_df["int64_too"]
1139+
.apply(foo_list, args=(12.34, b"hello world", False))
1140+
.to_pandas()
1141+
)
1142+
pd_result = scalars_pandas_df["int64_too"].apply(
1143+
foo_list, args=(12.34, b"hello world", False)
1144+
)
1145+
1146+
# Ignore any dtype difference.
1147+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
1148+
1149+
finally:
1150+
# Clean up the gcp assets created for the managed function.
1151+
cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False)

tests/system/large/functions/test_remote_function.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,3 +2979,42 @@ def _ten_times(x):
29792979
finally:
29802980
# Clean up the gcp assets created for the remote function.
29812981
cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False)
2982+
2983+
2984+
@pytest.mark.flaky(retries=2, delay=120)
2985+
def test_remote_function_series_apply_args(session, dataset_id, scalars_dfs):
2986+
try:
2987+
2988+
@session.remote_function(
2989+
dataset=dataset_id,
2990+
reuse=False,
2991+
cloud_function_service_account="default",
2992+
)
2993+
def foo(x: int, y: bool, z: float) -> str:
2994+
if y:
2995+
return f"{x}: y is True."
2996+
if z > 0.0:
2997+
return f"{x}: y is False and z is positive."
2998+
return f"{x}: y is False and z is non-positive."
2999+
3000+
scalars_df, scalars_pandas_df = scalars_dfs
3001+
3002+
args1 = (True, 10.0)
3003+
bf_result = scalars_df["int64_too"].apply(foo, args=args1).to_pandas()
3004+
pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args1)
3005+
3006+
# Ignore any dtype difference.
3007+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
3008+
3009+
args2 = (False, -10.0)
3010+
foo_ref = session.read_gbq_function(foo.bigframes_bigquery_function)
3011+
3012+
bf_result = scalars_df["int64_too"].apply(foo_ref, args=args2).to_pandas()
3013+
pd_result = scalars_pandas_df["int64_too"].apply(foo, args=args2)
3014+
3015+
# Ignore any dtype difference.
3016+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
3017+
3018+
finally:
3019+
# Clean up the gcp assets created for the remote function.
3020+
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)

0 commit comments

Comments
 (0)