Skip to content

Commit 62a189f

Browse files
authored
feat: support series input in managed function (#1920)
* feat: support series input in managed function * resolve the comments * resolve the comments * resolve the comments * fix message
1 parent 190f32e commit 62a189f

File tree

6 files changed

+393
-38
lines changed

6 files changed

+393
-38
lines changed

bigframes/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
103103
"""Remote Function and Managed UDF with axis=1 preview."""
104104

105105

106+
class FunctionPackageVersionWarning(PreviewWarning):
    """Warn that managed UDF package versions may differ from those expected.

    The Numpy, Pandas, and Pyarrow versions used by a managed UDF may not
    precisely match the user's local environment or the exact versions
    specified.
    """
111+
112+
106113
def format_message(message: str, fill: bool = True):
107114
"""Formats a warning message with ANSI color codes for the warning color.
108115

bigframes/functions/_function_client.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import logging
2020
import os
2121
import random
22-
import re
2322
import shutil
2423
import string
2524
import tempfile
@@ -247,7 +246,7 @@ def provision_bq_managed_function(
247246
# Augment user package requirements with any internal package
248247
# requirements.
249248
packages = _utils._get_updated_package_requirements(
250-
packages, is_row_processor, capture_references
249+
packages, is_row_processor, capture_references, ignore_package_version=True
251250
)
252251
if packages:
253252
managed_function_options["packages"] = packages
@@ -270,28 +269,6 @@ def provision_bq_managed_function(
270269
)
271270

272271
udf_name = func.__name__
273-
if capture_references:
274-
# This code path ensures that if the udf body contains any
275-
# references to variables and/or imports outside the body, they are
276-
# captured as well.
277-
import cloudpickle
278-
279-
pickled = cloudpickle.dumps(func)
280-
udf_code = textwrap.dedent(
281-
f"""
282-
import cloudpickle
283-
{udf_name} = cloudpickle.loads({pickled})
284-
"""
285-
)
286-
else:
287-
# This code path ensures that if the udf body is self contained,
288-
# i.e. there are no references to variables or imports outside the
289-
# body.
290-
udf_code = textwrap.dedent(inspect.getsource(func))
291-
match = re.search(r"^def ", udf_code, flags=re.MULTILINE)
292-
if match is None:
293-
raise ValueError("The UDF is not defined correctly.")
294-
udf_code = udf_code[match.start() :]
295272

296273
with_connection_clause = (
297274
(
@@ -301,6 +278,13 @@ def provision_bq_managed_function(
301278
else ""
302279
)
303280

281+
# Generate the complete Python code block for the managed Python UDF,
282+
# including the user's function, necessary imports, and the BigQuery
283+
# handler wrapper.
284+
python_code_block = bff_template.generate_managed_function_code(
285+
func, udf_name, is_row_processor, capture_references
286+
)
287+
304288
create_function_ddl = (
305289
textwrap.dedent(
306290
f"""
@@ -311,13 +295,11 @@ def provision_bq_managed_function(
311295
OPTIONS ({managed_function_options_str})
312296
AS r'''
313297
__UDF_PLACE_HOLDER__
314-
def bigframes_handler(*args):
315-
return {udf_name}(*args)
316298
'''
317299
"""
318300
)
319301
.strip()
320-
.replace("__UDF_PLACE_HOLDER__", udf_code)
302+
.replace("__UDF_PLACE_HOLDER__", python_code_block)
321303
)
322304

323305
self._ensure_dataset_exists()

bigframes/functions/_function_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,15 +847,15 @@ def wrapper(func):
847847
if output_type:
848848
py_sig = py_sig.replace(return_annotation=output_type)
849849

850-
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
851-
852850
# The function will actually be receiving a pandas Series, but allow
853851
# both BigQuery DataFrames and pandas object types for compatibility.
854852
is_row_processor = False
855853
if new_sig := _convert_row_processor_sig(py_sig):
856854
py_sig = new_sig
857855
is_row_processor = True
858856

857+
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
858+
859859
managed_function_client = _function_client.FunctionClient(
860860
dataset_ref.project,
861861
bq_location,

bigframes/functions/_utils.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import sys
1919
import typing
2020
from typing import cast, Optional, Set
21+
import warnings
2122

2223
import cloudpickle
2324
import google.api_core.exceptions
@@ -26,6 +27,7 @@
2627
import pandas
2728
import pyarrow
2829

30+
import bigframes.exceptions as bfe
2931
import bigframes.formatting_helpers as bf_formatting
3032
from bigframes.functions import function_typing
3133

@@ -61,21 +63,40 @@ def get_remote_function_locations(bq_location):
6163

6264

6365
def _get_updated_package_requirements(
64-
package_requirements=None, is_row_processor=False, capture_references=True
66+
package_requirements=None,
67+
is_row_processor=False,
68+
capture_references=True,
69+
ignore_package_version=False,
6570
):
6671
requirements = []
6772
if capture_references:
6873
requirements.append(f"cloudpickle=={cloudpickle.__version__}")
6974

7075
if is_row_processor:
71-
# bigframes function will send an entire row of data as json, which
72-
# would be converted to a pandas series and processed Ensure numpy
73-
# versions match to avoid unpickling problems. See internal issue
74-
# b/347934471.
75-
requirements.append(f"numpy=={numpy.__version__}")
76-
requirements.append(f"pandas=={pandas.__version__}")
77-
requirements.append(f"pyarrow=={pyarrow.__version__}")
78-
76+
if ignore_package_version:
77+
# TODO(jialuo): Add back the version after b/410924784 is resolved.
78+
# Due to current limitations on the packages version in Python UDFs,
79+
# we use `ignore_package_version` to optionally omit the version for
80+
# managed functions only.
81+
msg = bfe.format_message(
82+
"numpy, pandas, and pyarrow versions in the function execution"
83+
" environment may not precisely match your local environment."
84+
)
85+
warnings.warn(msg, category=bfe.FunctionPackageVersionWarning)
86+
requirements.append("pandas")
87+
requirements.append("pyarrow")
88+
requirements.append("numpy")
89+
else:
90+
# bigframes function will send an entire row of data as json, which
91+
# would be converted to a pandas series and processed. Ensure numpy
92+
# versions match to avoid unpickling problems. See internal issue
93+
# b/347934471.
94+
requirements.append(f"pandas=={pandas.__version__}")
95+
requirements.append(f"pyarrow=={pyarrow.__version__}")
96+
requirements.append(f"numpy=={numpy.__version__}")
97+
98+
# TODO(b/435023957): Fix the issue of potential duplicate package versions
99+
# when `package_requirements` also contains `pandas/pyarrow/numpy`.
79100
if package_requirements:
80101
requirements.extend(package_requirements)
81102

bigframes/functions/function_template.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import inspect
1818
import logging
1919
import os
20+
import re
2021
import textwrap
2122
from typing import Tuple
2223

@@ -291,3 +292,55 @@ def generate_cloud_function_main_code(
291292
logger.debug(f"Wrote {os.path.abspath(main_py)}:\n{open(main_py).read()}")
292293

293294
return handler_func_name
295+
296+
297+
def _extract_def_source(func) -> str:
    """Return the dedented source of ``func`` starting at its ``def`` line.

    Anything that precedes the first line beginning with ``def `` (for
    example decorator lines) is dropped.

    Raises:
        ValueError: If no line of the dedented source starts with ``def ``.
    """
    source = textwrap.dedent(inspect.getsource(func))
    # Anchor on a line start so that a stray "def" substring (e.g. the word
    # "default" in a decorator or comment) is never mistaken for the
    # definition itself.
    match = re.search(r"^def ", source, flags=re.MULTILINE)
    if match is None:
        raise ValueError("The UDF is not defined correctly.")
    return source[match.start() :]


def generate_managed_function_code(
    def_,
    udf_name: str,
    is_row_processor: bool,
    capture_references: bool,
) -> str:
    """Generates the Python code block for managed Python UDF.

    The returned block is made of up to three parts, joined by newlines:
    an optional helper definition (row processors only), code that defines
    the user's function under the name ``udf_name``, and a
    ``bigframes_handler`` function that calls it.

    Args:
        def_: The user's Python function.
        udf_name: Name that the user's function is bound to in the
            generated code; the handler calls it by this name.
        is_row_processor: If True, the source of ``get_pd_series`` is
            embedded and the handler takes a single ``str_arg`` which is
            passed through ``get_pd_series`` before calling the user's
            function. Otherwise the handler forwards ``*args`` unchanged.
        capture_references: If True, the function is serialized with
            cloudpickle and the generated code unpickles it. If False, the
            function's own source text is embedded.

    Returns:
        The Python source code of the managed function body.

    Raises:
        ValueError: If ``capture_references`` is False and no ``def`` line
            can be found in the function's source.
    """
    if capture_references:
        # This code path ensures that if the udf body contains any
        # references to variables and/or imports outside the body, they are
        # captured as well.
        import cloudpickle

        pickled = cloudpickle.dumps(def_)
        func_code = textwrap.dedent(
            f"""
            import cloudpickle
            {udf_name} = cloudpickle.loads({pickled})
            """
        )
    else:
        # This code path ensures that if the udf body is self contained,
        # i.e. there are no references to variables or imports outside the
        # body.
        func_code = _extract_def_source(def_)

    if is_row_processor:
        # Embed the helper so that it is available inside the UDF, and
        # locate its `def` line the same way as for the user's function.
        udf_code = _extract_def_source(get_pd_series)
        bigframes_handler_code = (
            "def bigframes_handler(str_arg):\n"
            f"    return {udf_name}({get_pd_series.__name__}(str_arg))"
        )
    else:
        udf_code = ""
        bigframes_handler_code = (
            "def bigframes_handler(*args):\n"
            f"    return {udf_name}(*args)"
        )

    # Every part already starts at column 0, so no further dedenting is
    # needed when they are joined.
    return f"{udf_code}\n{func_code}\n{bigframes_handler_code}"

0 commit comments

Comments
 (0)