Commit 93353c3

refactor: use unit testable real python functions for most templated remote_functions code (#751)
* refactor: use unit testable real python functions for most templated remote_functions code
* revert changes and add unit tests
* Update tests/unit/functions/test_remote_function_template.py
* mypy failure
* Update bigframes/functions/remote_function_template.py
1 parent fb8cf8f commit 93353c3
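
The gist of the change, sketched below under assumed names: handler logic that previously existed only inside a generated source-code string (and so could only be exercised by actually deploying a cloud function) now lives as real, importable Python in bigframes/functions/remote_function_template.py, where unit tests can call it directly. `udf_http_sketch` and its signature are illustrative assumptions, not the actual bigframes API.

```python
# Illustrative sketch of the refactor's idea; names here are hypothetical.

# Before: logic embedded in a template string. It only becomes executable
# inside the generated main.py of a deployed cloud function, so a unit
# test cannot reach it.
CODE_TEMPLATE = """
def udf_http(request):
    calls = request.get_json(silent=True)["calls"]
    return {"replies": [udf(*call) for call in calls]}
"""

# After: the same logic written as a real function that a test can import
# and call with a stub request object.
def udf_http_sketch(request, udf):
    calls = request.get_json(silent=True)["calls"]
    return {"replies": [udf(*call) for call in calls]}
```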

File tree

4 files changed: +284 −171


bigframes/functions/remote_function.py

Lines changed: 8 additions & 171 deletions
```diff
@@ -24,7 +24,6 @@
 import string
 import sys
 import tempfile
-import textwrap
 from typing import (
     Any,
     cast,
@@ -61,6 +60,7 @@
 from bigframes import clients
 import bigframes.constants as constants
 import bigframes.dtypes
+import bigframes.functions.remote_function_template
 
 logger = logging.getLogger(__name__)
 
@@ -258,171 +258,8 @@ def get_cloud_function_endpoint(self, name):
             pass
         return None
 
-    def generate_udf_code(self, def_, dir):
-        """Generate serialized bytecode using cloudpickle given a udf."""
-        udf_code_file_name = "udf.py"
-        udf_bytecode_file_name = "udf.cloudpickle"
-
-        # original code, only for debugging purpose
-        udf_code = textwrap.dedent(inspect.getsource(def_))
-        udf_code_file_path = os.path.join(dir, udf_code_file_name)
-        with open(udf_code_file_path, "w") as f:
-            f.write(udf_code)
-
-        # serialized bytecode
-        udf_bytecode_file_path = os.path.join(dir, udf_bytecode_file_name)
-        with open(udf_bytecode_file_path, "wb") as f:
-            cloudpickle.dump(def_, f, protocol=_pickle_protocol_version)
-
-        return udf_code_file_name, udf_bytecode_file_name
-
-    def generate_cloud_function_main_code(self, def_, dir, is_row_processor=False):
-        """Get main.py code for the cloud function for the given user defined function."""
-
-        # Pickle the udf with all its dependencies
-        udf_code_file, udf_bytecode_file = self.generate_udf_code(def_, dir)
-        handler_func_name = "udf_http"
-
-        # We want to build a cloud function that works for BQ remote functions,
-        # where we receive `calls` in json which is a batch of rows from BQ SQL.
-        # The number and the order of values in each row is expected to exactly
-        # match to the number and order of arguments in the udf, e.g. if the udf is
-        #   def foo(x: int, y: str):
-        #     ...
-        # then the http request body could look like
-        # {
-        #   ...
-        #   "calls" : [
-        #     [123, "hello"],
-        #     [456, "world"]
-        #   ]
-        #   ...
-        # }
-        # https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#input_format
-        code = """\
-import cloudpickle
-import functions_framework
-from flask import jsonify
-import json
-"""
-        if is_row_processor:
-            code += """\
-import ast
-import math
-import pandas as pd
-
-def get_pd_series(row):
-    row_json = json.loads(row)
-    col_names = row_json["names"]
-    col_types = row_json["types"]
-    col_values = row_json["values"]
-    index_length = row_json["indexlength"]
-    dtype = row_json["dtype"]
-
-    # At this point we are assuming that col_names, col_types and col_values are
-    # arrays of the same length, representing column names, types and values for
-    # one row of data
-
-    # column names are not necessarily strings
-    # they are serialized as repr(name) at source
-    evaluated_col_names = []
-    for col_name in col_names:
-        try:
-            col_name = ast.literal_eval(col_name)
-        except Exception as ex:
-            raise NameError(f"Failed to evaluate column name from '{col_name}': {ex}")
-        evaluated_col_names.append(col_name)
-    col_names = evaluated_col_names
-
-    # Supported converters for pandas to python types
-    value_converters = {
-        "boolean": lambda val: val == "true",
-        "Int64": int,
-        "Float64": float,
-        "string": str,
-    }
-
-    def convert_value(value, value_type):
-        value_converter = value_converters.get(value_type)
-        if value_converter is None:
-            raise ValueError(f"Don't know how to handle type '{value_type}'")
-        if value is None:
-            return None
-        return value_converter(value)
-
-    index_values = [
-        pd.Series([convert_value(col_values[i], col_types[i])], dtype=col_types[i])[0]
-        for i in range(index_length)
-    ]
-
-    data_col_names = col_names[index_length:]
-    data_col_types = col_types[index_length:]
-    data_col_values = col_values[index_length:]
-    data_col_values = [
-        pd.Series([convert_value(a, data_col_types[i])], dtype=data_col_types[i])[0]
-        for i, a in enumerate(data_col_values)
-    ]
-
-    row_index = index_values[0] if len(index_values) == 1 else tuple(index_values)
-    row_series = pd.Series(data_col_values, index=data_col_names, name=row_index, dtype=dtype)
-    return row_series
-"""
-        code += f"""\
-
-# original udf code is in {udf_code_file}
-# serialized udf code is in {udf_bytecode_file}
-with open("{udf_bytecode_file}", "rb") as f:
-    udf = cloudpickle.load(f)
-
-def {handler_func_name}(request):
-    try:
-        request_json = request.get_json(silent=True)
-        calls = request_json["calls"]
-        replies = []
-        for call in calls:
-"""
-
-        if is_row_processor:
-            code += """\
-            reply = udf(get_pd_series(call[0]))
-            if isinstance(reply, float) and (math.isnan(reply) or math.isinf(reply)):
-                # json serialization of the special float values (nan, inf, -inf)
-                # is not in strict compliance of the JSON specification
-                # https://docs.python.org/3/library/json.html#basic-usage.
-                # Let's convert them to a quoted string representation ("NaN",
-                # "Infinity", "-Infinity" respectively) which is handled by
-                # BigQuery
-                reply = json.dumps(reply)
-            elif pd.isna(reply):
-                # Pandas N/A values are not json serializable, so use a python
-                # equivalent instead
-                reply = None
-            elif hasattr(reply, "item"):
-                # Numpy types are not json serializable, so use its Python
-                # value instead
-                reply = reply.item()
-"""
-        else:
-            code += """\
-            reply = udf(*call)
-"""
-        code += """\
-            replies.append(reply)
-        return_json = json.dumps({"replies" : replies})
-        return return_json
-    except Exception as e:
-        return jsonify( { "errorMessage": str(e) } ), 400
-"""
-
-        main_py = os.path.join(dir, "main.py")
-        with open(main_py, "w") as f:
-            f.write(code)
-        logger.debug(f"Wrote {os.path.abspath(main_py)}:\n{open(main_py).read()}")
-
-        return handler_func_name
-
     def generate_cloud_function_code(
-        self, def_, dir, package_requirements=None, is_row_processor=False
+        self, def_, directory, package_requirements=None, is_row_processor=False
     ):
         """Generate the cloud function code for a given user defined function."""
 
```
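The removed template above encodes the BigQuery remote function contract documented at the URL in its comments: the HTTP body carries a `calls` array of rows, each row being the argument list for one UDF invocation, and the response is JSON with a `replies` array of the same length. A minimal self-contained sketch of that round trip, with `double` standing in for a deployed UDF:

```python
import json

def double(x: int) -> int:
    return x * 2

# Shape of the request body BigQuery sends to the cloud function:
# each inner list is one row of UDF arguments.
request_json = {"calls": [[123], [456]]}

# The generated handler applies the unpickled UDF to every row and
# returns the results under "replies".
replies = [double(*call) for call in request_json["calls"]]
print(json.dumps({"replies": replies}))  # {"replies": [246, 912]}
```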
```diff
@@ -435,13 +272,13 @@ def generate_cloud_function_code(
         if package_requirements:
             requirements.extend(package_requirements)
         requirements = sorted(requirements)
-        requirements_txt = os.path.join(dir, "requirements.txt")
+        requirements_txt = os.path.join(directory, "requirements.txt")
         with open(requirements_txt, "w") as f:
             f.write("\n".join(requirements))
 
         # main.py
-        entry_point = self.generate_cloud_function_main_code(
-            def_, dir, is_row_processor
+        entry_point = bigframes.functions.remote_function_template.generate_cloud_function_main_code(
+            def_, directory, is_row_processor
         )
         return entry_point
 
```
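With main.py generation moved onto the importable remote_function_template module, it can be exercised in a plain unit test with no GCP resources. A hypothetical sketch of such a test (the real tests live in tests/unit/functions/test_remote_function_template.py; the assertion below is an illustrative assumption, not the repo's actual test code):

```python
import os
import tempfile

import bigframes.functions.remote_function_template as template


def test_generate_cloud_function_main_code_writes_handler():
    def add_one(x: int) -> int:
        return x + 1

    with tempfile.TemporaryDirectory() as directory:
        entry_point = template.generate_cloud_function_main_code(
            add_one, directory, is_row_processor=False
        )
        # The generated main.py should define the returned entry point
        # (named "udf_http" in the code this commit replaces).
        with open(os.path.join(directory, "main.py")) as f:
            assert f"def {entry_point}" in f.read()
```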
```diff
@@ -458,11 +295,11 @@ def create_cloud_function(
         """Create a cloud function from the given user defined function."""
 
         # Build and deploy folder structure containing cloud function
-        with tempfile.TemporaryDirectory() as dir:
+        with tempfile.TemporaryDirectory() as directory:
             entry_point = self.generate_cloud_function_code(
-                def_, dir, package_requirements, is_row_processor
+                def_, directory, package_requirements, is_row_processor
             )
-            archive_path = shutil.make_archive(dir, "zip", dir)
+            archive_path = shutil.make_archive(directory, "zip", directory)
 
             # We are creating cloud function source code from the currently running
             # python version. Use the same version to deploy. This is necessary
```
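
For context on the packaging step above: shutil.make_archive(base_name, format, root_dir) writes base_name + ".zip" containing the contents of root_dir, so passing the temporary directory for both arguments zips the generated sources into an archive next to the directory itself. A minimal sketch:

```python
import pathlib
import shutil
import tempfile

with tempfile.TemporaryDirectory() as directory:
    # Stand-ins for the generated cloud function sources.
    pathlib.Path(directory, "main.py").write_text("def udf_http(request): ...\n")
    pathlib.Path(directory, "requirements.txt").write_text("cloudpickle\n")

    # Writes "<directory>.zip" with the directory's contents at the archive
    # root, the layout cloud function source deployment expects.
    archive_path = shutil.make_archive(directory, "zip", directory)
    print(archive_path)  # e.g. /tmp/tmpabc123.zip
```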
