
Commit da9938c

support remote function

1 parent 4c4f807

File tree: 3 files changed (+114, -1)

bigframes/functions/function_template.py

Lines changed: 3 additions & 1 deletion
@@ -195,7 +195,9 @@ def udf_http_row_processor(request):
     calls = request_json["calls"]
     replies = []
     for call in calls:
-        reply = convert_to_bq_json(output_type, udf(get_pd_series(call[0])))
+        reply = convert_to_bq_json(
+            output_type, udf(get_pd_series(call[0]), *call[1:])
+        )
         if type(reply) is list:
             # Since the BQ remote function does not support array yet,
             # return a json serialized version of the reply.
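
For context on the template change: a BigQuery remote function endpoint receives a JSON body whose "calls" field carries one argument list per row. For a row-processor UDF, call[0] is the serialized row (which get_pd_series turns into a pandas Series) and, with this commit, any trailing elements are forwarded as extra positional arguments via *call[1:]. Below is a minimal runnable sketch of that unpacking, using hypothetical stand-ins for get_pd_series and the user udf:

import json

# Hypothetical stand-ins, only to make the forwarding visible; the real
# get_pd_series builds a pandas Series from the serialized row data.
def get_pd_series(serialized_row):
    return json.loads(serialized_row)  # stand-in: a dict instead of a Series

def udf(row, x):
    return row["int64_col"] + row["int64_too"] + x

request_json = {
    "calls": [
        # [serialized_row, extra positional args...]
        ['{"int64_col": 1, "int64_too": 2}', 10],
        ['{"int64_col": 3, "int64_too": 4}', 10],
    ]
}

replies = [udf(get_pd_series(call[0]), *call[1:]) for call in request_json["calls"]]
print(replies)  # [13, 17]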

tests/system/large/functions/test_managed_function.py

Lines changed: 17 additions & 0 deletions
@@ -982,6 +982,23 @@ def the_sum(s1, s2, x):
         )(the_sum)

         args1 = (1,)
+
+        # Fails to apply on dataframe with incompatible number of columns.
+        with pytest.raises(
+            ValueError,
+            match="^Column count mismatch: BigFrames BigQuery function expected 2 columns from DataFrame but received 3\\.$",
+        ):
+            scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1)
+
+        # Fails to apply on dataframe with incompatible column datatypes.
+        with pytest.raises(
+            ValueError,
+            match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
+        ):
+            scalars_df[columns].assign(
+                int64_col=lambda df: df["int64_col"].astype("Float64")
+            ).apply(the_sum_mf, axis=1, args=args1)
+
         bf_result = (
             scalars_df[columns]
             .dropna()
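
The two pytest.raises blocks pin down the argument mapping for a scalar-signature function: the selected DataFrame columns bind to the leading parameters and args fills the trailing ones, so a third column (or a Float64 column against an int parameter) no longer matches the int, int, int signature. A pure-Python sketch of that mapping, with toy row values rather than the test fixtures:

def the_sum(s1, s2, x):
    return s1 + s2 + x

# Two columns bind to s1 and s2; args=(1,) fills x. A third column
# would mean four positional values for a three-parameter function,
# which is the column-count mismatch the test asserts on.
rows = [(1, 3), (2, 4)]  # toy (int64_col, int64_too) pairs
args1 = (1,)
print([the_sum(*row, *args1) for row in rows])  # [5, 7]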

tests/system/large/functions/test_remote_function.py

Lines changed: 94 additions & 0 deletions
@@ -1937,6 +1937,100 @@ def float_parser(row):
     )


+@pytest.mark.flaky(retries=2, delay=120)
+def test_df_apply_axis_1_args(session, scalars_dfs):
+    columns = ["int64_col", "int64_too"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    try:
+
+        def the_sum(s1, s2, x):
+            return s1 + s2 + x
+
+        the_sum_mf = session.remote_function(
+            input_types=[int, int, int],
+            output_type=int,
+            reuse=False,
+            cloud_function_service_account="default",
+        )(the_sum)
+
+        args1 = (1,)
+
+        # Fails to apply on dataframe with incompatible number of columns.
+        with pytest.raises(
+            ValueError,
+            match="^Column count mismatch: BigFrames BigQuery function expected 2 columns from DataFrame but received 3\\.$",
+        ):
+            scalars_df[columns + ["float64_col"]].apply(the_sum_mf, axis=1, args=args1)
+
+        # Fails to apply on dataframe with incompatible column datatypes.
+        with pytest.raises(
+            ValueError,
+            match="^Data type mismatch: BigFrames BigQuery function takes arguments of types .* but DataFrame dtypes are .*",
+        ):
+            scalars_df[columns].assign(
+                int64_col=lambda df: df["int64_col"].astype("Float64")
+            ).apply(the_sum_mf, axis=1, args=args1)
+
+        bf_result = (
+            scalars_df[columns]
+            .dropna()
+            .apply(the_sum_mf, axis=1, args=args1)
+            .to_pandas()
+        )
+        pd_result = scalars_pandas_df[columns].dropna().apply(sum, axis=1, args=args1)
+
+        pandas.testing.assert_series_equal(pd_result, bf_result, check_dtype=False)
+
+    finally:
+        # clean up the gcp assets created for the remote function.
+        cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
+
+
+@pytest.mark.flaky(retries=2, delay=120)
+def test_df_apply_axis_1_series_args(session, scalars_dfs):
+    columns = ["int64_col", "float64_col"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    try:
+
+        @session.remote_function(
+            input_types=[bigframes.series.Series, float, str, bool],
+            output_type=list[str],
+            reuse=False,
+            cloud_function_service_account="default",
+        )
+        def foo_list(x, y0: float, y1, y2) -> list[str]:
+            return (
+                [str(x["int64_col"]), str(y0), str(y1), str(y2)]
+                if y2
+                else [str(x["float64_col"])]
+            )
+
+        args1 = (12.34, "hello world", True)
+        bf_result = scalars_df[columns].apply(foo_list, axis=1, args=args1).to_pandas()
+        pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=args1)
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
+
+        args2 = (43.21, "xxx3yyy", False)
+        foo_list_ref = session.read_gbq_function(
+            foo_list.bigframes_bigquery_function, is_row_processor=True
+        )
+        bf_result = (
+            scalars_df[columns].apply(foo_list_ref, axis=1, args=args2).to_pandas()
+        )
+        pd_result = scalars_pandas_df[columns].apply(foo_list, axis=1, args=args2)
+
+        # Ignore any dtype difference.
+        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
+
+    finally:
+        # Clean up the gcp assets created for the remote function.
+        cleanup_function_assets(foo_list, session.bqclient, ignore_failures=False)
+
+
 @pytest.mark.parametrize(
     ("memory_mib_args", "expected_memory"),
     [
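
The second new test covers the row-processor flavor: the first input type is bigframes.series.Series, so the whole row arrives as x and args supplies the remaining scalars, both for the freshly deployed function and for the handle re-read with read_gbq_function(..., is_row_processor=True). The same mapping can be checked locally with plain pandas, whose DataFrame.apply passes the row Series as the first positional argument and unpacks args after it (toy frame, not the test fixtures):

import pandas as pd

def foo_list(x, y0: float, y1, y2) -> list[str]:
    # x is the whole row (a Series); y0, y1, y2 come from args.
    return (
        [str(x["int64_col"]), str(y0), str(y1), str(y2)]
        if y2
        else [str(x["float64_col"])]
    )

df = pd.DataFrame({"int64_col": [1], "float64_col": [2.5]})
print(df.apply(foo_list, axis=1, args=(12.34, "hello world", True)).iloc[0])
# ['1', '12.34', 'hello world', 'True']
print(df.apply(foo_list, axis=1, args=(43.21, "xxx3yyy", False)).iloc[0])
# ['2.5']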
