Skip to content

Commit 4e76f53

Browse files
committed
feat: Support args in series apply method
1 parent b454256 commit 4e76f53

File tree

3 files changed

+101
-5
lines changed

3 files changed

+101
-5
lines changed

bigframes/series.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,9 +1898,13 @@ def _groupby_values(
18981898
)
18991899

19001900
def apply(
1901-
self, func, by_row: typing.Union[typing.Literal["compat"], bool] = "compat"
1901+
self,
1902+
func,
1903+
by_row: typing.Union[typing.Literal["compat"], bool] = "compat",
1904+
*,
1905+
args: typing.Tuple = (),
19021906
) -> Series:
1903-
# TODO(shobs, b/274645634): Support convert_dtype, args, **kwargs
1907+
# TODO(shobs, b/274645634): Support convert_dtype, **kwargs
19041908
# is actually a ternary op
19051909

19061910
if by_row not in ["compat", False]:
@@ -1944,10 +1948,16 @@ def apply(
19441948
raise
19451949

19461950
# We are working with bigquery function at this point
1947-
result_series = self._apply_unary_op(
1948-
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
1949-
)
1951+
if args:
1952+
result_series = self._apply_nary_op(
1953+
ops.NaryRemoteFunctionOp(function_def=func.udf_def), args
1954+
)
1955+
else:
1956+
result_series = self._apply_unary_op(
1957+
ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
1958+
)
19501959
result_series = func._post_process_series(result_series)
1960+
19511961
return result_series
19521962

19531963
def combine(

tests/system/large/functions/test_managed_function.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,3 +1111,36 @@ def _is_positive(s):
11111111
finally:
11121112
# Clean up the gcp assets created for the managed function.
11131113
cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)
1114+
1115+
1116+
def test_managed_function_series_apply_args(session, dataset_id, scalars_dfs):
1117+
try:
1118+
1119+
with pytest.warns(bfe.PreviewWarning, match="udf is in preview."):
1120+
1121+
@session.udf(dataset=dataset_id, name=prefixer.create_prefix())
1122+
def foo_list(x: int, y0: float, y1: bytes, y2: bool) -> list[str]:
1123+
return [str(x), str(y0), str(y1), str(y2)]
1124+
1125+
scalars_df, scalars_pandas_df = scalars_dfs
1126+
1127+
bf_result_col = scalars_df["int64_too"].apply(
1128+
foo_list, args=(12.34, b"hello world", False)
1129+
)
1130+
bf_result = (
1131+
scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
1132+
)
1133+
1134+
pd_result_col = scalars_pandas_df["int64_too"].apply(
1135+
foo_list, args=(12.34, b"hello world", False)
1136+
)
1137+
pd_result = (
1138+
scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
1139+
)
1140+
1141+
# Ignore any dtype difference.
1142+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
1143+
1144+
finally:
1145+
# Clean up the gcp assets created for the managed function.
1146+
cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False)

tests/system/large/functions/test_remote_function.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2969,3 +2969,56 @@ def _ten_times(x):
29692969
finally:
29702970
# Clean up the gcp assets created for the remote function.
29712971
cleanup_function_assets(ten_times_mf, session.bqclient, ignore_failures=False)
2972+
2973+
2974+
@pytest.mark.flaky(retries=2, delay=120)
2975+
def test_remote_function_series_apply_args(session, dataset_id, scalars_dfs):
2976+
try:
2977+
2978+
@session.remote_function(
2979+
dataset=dataset_id,
2980+
reuse=False,
2981+
cloud_function_service_account="default",
2982+
)
2983+
def foo(x: int, y: bool, z: float) -> str:
2984+
if y:
2985+
return f"{x}: y is True."
2986+
if z > 0.0:
2987+
return f"{x}: y is False and z is positive."
2988+
return f"{x}: y is False and z is non-positive."
2989+
2990+
scalars_df, scalars_pandas_df = scalars_dfs
2991+
2992+
args1 = (True, 10.0)
2993+
bf_result_col = scalars_df["int64_too"].apply(foo, args=args1)
2994+
bf_result = (
2995+
scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
2996+
)
2997+
2998+
pd_result_col = scalars_pandas_df["int64_too"].apply(foo, args=args1)
2999+
pd_result = (
3000+
scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
3001+
)
3002+
3003+
# Ignore any dtype difference.
3004+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
3005+
3006+
args2 = (False, -10.0)
3007+
foo_ref = session.read_gbq_function(foo.bigframes_bigquery_function)
3008+
3009+
bf_result_col = scalars_df["int64_too"].apply(foo_ref, args=args2)
3010+
bf_result = (
3011+
scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
3012+
)
3013+
3014+
pd_result_col = scalars_pandas_df["int64_too"].apply(foo, args=args2)
3015+
pd_result = (
3016+
scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
3017+
)
3018+
3019+
# Ignore any dtype difference.
3020+
pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
3021+
3022+
finally:
3023+
# Clean up the gcp assets created for the remote function.
3024+
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)

0 commit comments

Comments
 (0)