
Commit f984381

refactor: Refactor udf definitions (#1814)
1 parent 1e8a2f1 commit f984381

File tree

23 files changed: +791 additions, -921 deletions

bigframes/core/compile/ibis_types.py

Lines changed: 0 additions & 48 deletions
@@ -13,20 +13,14 @@
 # limitations under the License.
 from __future__ import annotations
 
-import typing
 from typing import cast, Dict, Iterable, Optional, Tuple, Union
 
 import bigframes_vendored.constants as constants
 import bigframes_vendored.ibis
-import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
 import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
-from bigframes_vendored.ibis.expr.datatypes.core import (
-    dtype as python_type_to_ibis_type,
-)
 import bigframes_vendored.ibis.expr.types as ibis_types
 import db_dtypes  # type: ignore
 import geopandas as gpd  # type: ignore
-import google.cloud.bigquery as bigquery
 import pandas as pd
 import pyarrow as pa
 
@@ -439,45 +433,3 @@ def literal_to_ibis_scalar(
         )
 
     return scalar_expr
-
-
-class UnsupportedTypeError(ValueError):
-    def __init__(self, type_, supported_types):
-        self.type = type_
-        self.supported_types = supported_types
-        super().__init__(
-            f"'{type_}' is not one of the supported types {supported_types}"
-        )
-
-
-def ibis_type_from_python_type(t: type) -> ibis_dtypes.DataType:
-    if t not in bigframes.dtypes.RF_SUPPORTED_IO_PYTHON_TYPES:
-        raise UnsupportedTypeError(t, bigframes.dtypes.RF_SUPPORTED_IO_PYTHON_TYPES)
-    return python_type_to_ibis_type(t)
-
-
-def ibis_array_output_type_from_python_type(t: type) -> ibis_dtypes.DataType:
-    array_of = typing.get_args(t)[0]
-    if array_of not in bigframes.dtypes.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES:
-        raise UnsupportedTypeError(
-            array_of, bigframes.dtypes.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES
-        )
-    return python_type_to_ibis_type(t)
-
-
-def ibis_type_from_bigquery_type(
-    type_: bigquery.StandardSqlDataType,
-) -> ibis_dtypes.DataType:
-    """Convert bq type to ibis. Only to be used for remote functions, does not handle all types."""
-    if type_.type_kind not in bigframes.dtypes.RF_SUPPORTED_IO_BIGQUERY_TYPEKINDS:
-        raise UnsupportedTypeError(
-            type_.type_kind, bigframes.dtypes.RF_SUPPORTED_IO_BIGQUERY_TYPEKINDS
-        )
-    elif type_.type_kind == "ARRAY":
-        return ibis_dtypes.Array(
-            value_type=ibis_type_from_bigquery_type(
-                typing.cast(bigquery.StandardSqlDataType, type_.array_element_type)
-            )
-        )
-    else:
-        return third_party_ibis_bqtypes.BigQueryType.to_ibis(type_.type_kind)

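The deleted helpers delegated the Python-type-to-ibis-dtype mapping to the vendored ibis `dtype` constructor (imported above as `python_type_to_ibis_type`) and only layered a supported-type allowlist on top. A minimal sketch of that underlying mapping, assuming the vendored ibis behaves like upstream ibis datatypes here:

    # Sketch only: the python-type -> ibis-dtype mapping the deleted helpers wrapped.
    # Assumes bigframes_vendored.ibis mirrors upstream ibis for these conversions.
    from bigframes_vendored.ibis.expr.datatypes.core import dtype as python_type_to_ibis_type

    print(python_type_to_ibis_type(int))   # int64
    print(python_type_to_ibis_type(str))   # string
    print(python_type_to_ibis_type(bool))  # boolean
    # The removed wrappers raised UnsupportedTypeError for anything outside the
    # RF_SUPPORTED_* allowlists before calling this constructor.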
bigframes/core/compile/scalar_op_compiler.py

Lines changed: 49 additions & 34 deletions
@@ -17,7 +17,6 @@
 import functools
 import typing
 
-import bigframes_vendored.constants as constants
 import bigframes_vendored.ibis.expr.api as ibis_api
 import bigframes_vendored.ibis.expr.datatypes as ibis_dtypes
 import bigframes_vendored.ibis.expr.operations.generic as ibis_generic
@@ -30,6 +29,7 @@
 import bigframes.core.compile.default_ordering
 import bigframes.core.compile.ibis_types
 import bigframes.core.expression as ex
+import bigframes.dtypes
 import bigframes.operations as ops
 
 _ZERO = typing.cast(ibis_types.NumericValue, ibis_types.literal(0))
@@ -1284,17 +1284,58 @@ def timedelta_floor_op_impl(x: ibis_types.NumericValue):
 
 @scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
 def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
-    ibis_node = getattr(op.func, "ibis_node", None)
-    if ibis_node is None:
-        raise TypeError(
-            f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
-        )
-    x_transformed = ibis_node(x)
+    udf_sig = op.function_def.signature
+    ibis_py_sig = (udf_sig.py_input_types, udf_sig.py_output_type)
+
+    @ibis_udf.scalar.builtin(
+        name=str(op.function_def.routine_ref), signature=ibis_py_sig
+    )
+    def udf(input):
+        ...
+
+    x_transformed = udf(x)
     if not op.apply_on_null:
-        x_transformed = ibis_api.case().when(x.isnull(), x).else_(x_transformed).end()
+        return ibis_api.case().when(x.isnull(), x).else_(x_transformed).end()
     return x_transformed
 
 
+@scalar_op_compiler.register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True)
+def binary_remote_function_op_impl(
+    x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
+):
+    udf_sig = op.function_def.signature
+    ibis_py_sig = (udf_sig.py_input_types, udf_sig.py_output_type)
+
+    @ibis_udf.scalar.builtin(
+        name=str(op.function_def.routine_ref), signature=ibis_py_sig
+    )
+    def udf(input1, input2):
+        ...
+
+    x_transformed = udf(x, y)
+    return x_transformed
+
+
+@scalar_op_compiler.register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
+def nary_remote_function_op_impl(
+    *operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
+):
+    udf_sig = op.function_def.signature
+    ibis_py_sig = (udf_sig.py_input_types, udf_sig.py_output_type)
+    arg_names = tuple(arg.name for arg in udf_sig.input_types)
+
+    @ibis_udf.scalar.builtin(
+        name=str(op.function_def.routine_ref),
+        signature=ibis_py_sig,
+        param_name_overrides=arg_names,
+    )
+    def udf(*inputs):
+        ...
+
+    result = udf(*operands)
+    return result
+
+
 @scalar_op_compiler.register_unary_op(ops.MapOp, pass_op=True)
 def map_op_impl(x: ibis_types.Value, op: ops.MapOp):
     case = ibis_api.case()
@@ -1931,19 +1972,6 @@ def manhattan_distance_impl(
     return vector_distance(vector1, vector2, "MANHATTAN")
 
 
-@scalar_op_compiler.register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True)
-def binary_remote_function_op_impl(
-    x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
-):
-    ibis_node = getattr(op.func, "ibis_node", None)
-    if ibis_node is None:
-        raise TypeError(
-            f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
-        )
-    x_transformed = ibis_node(x, y)
-    return x_transformed
-
-
 # Blob Ops
 @scalar_op_compiler.register_binary_op(ops.obj_make_ref_op)
 def obj_make_ref_op(x: ibis_types.Value, y: ibis_types.Value):
@@ -2005,19 +2033,6 @@ def case_when_op(*cases_and_outputs: ibis_types.Value) -> ibis_types.Value:
     return case_val.end()  # type: ignore
 
 
-@scalar_op_compiler.register_nary_op(ops.NaryRemoteFunctionOp, pass_op=True)
-def nary_remote_function_op_impl(
-    *operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
-):
-    ibis_node = getattr(op.func, "ibis_node", None)
-    if ibis_node is None:
-        raise TypeError(
-            f"only a bigframes remote function is supported as a callable. {constants.FEEDBACK_LINK}"
-        )
-    result = ibis_node(*operands)
-    return result
-
-
 @scalar_op_compiler.register_nary_op(ops.SqlScalarOp, pass_op=True)
 def sql_scalar_op_impl(*operands: ibis_types.Value, op: ops.SqlScalarOp):
     return ibis_generic.SqlScalar(

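The new compiler paths above no longer probe the user callable for an `ibis_node` attribute; instead they declare a throwaway ibis "builtin" scalar UDF whose name is the fully qualified BigQuery routine (`op.function_def.routine_ref`) and whose signature comes from the stored UDF definition, so the call compiles down to SQL invoking that routine. A minimal sketch of that pattern using upstream ibis (the routine name is hypothetical, and the vendored `ibis_udf.scalar.builtin` is assumed to behave like `ibis.udf.scalar.builtin`):

    # Sketch only: binding an ibis expression to an existing BigQuery routine by name.
    # "my_project.my_dataset.add_one" is a hypothetical, already-deployed routine.
    import ibis


    @ibis.udf.scalar.builtin(name="my_project.my_dataset.add_one")
    def add_one(x: int) -> int:  # body is irrelevant; only the name and signature matter
        ...


    t = ibis.table({"a": "int64"}, name="tbl")
    expr = t.select(result=add_one(t.a))
    # Should render a SELECT that calls my_project.my_dataset.add_one(a).
    print(ibis.to_sql(expr, dialect="bigquery"))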
bigframes/dataframe.py

Lines changed: 17 additions & 23 deletions
@@ -74,6 +74,7 @@
 import bigframes.dtypes
 import bigframes.exceptions as bfe
 import bigframes.formatting_helpers as formatter
+import bigframes.functions
 import bigframes.operations as ops
 import bigframes.operations.aggregations as agg_ops
 import bigframes.operations.ai
@@ -4470,15 +4471,17 @@ def _prepare_export(
         return array_value, id_overrides
 
     def map(self, func, na_action: Optional[str] = None) -> DataFrame:
-        if not callable(func):
+        if not isinstance(func, bigframes.functions.BigqueryCallableRoutine):
            raise TypeError("the first argument must be callable")
 
        if na_action not in {None, "ignore"}:
            raise ValueError(f"na_action={na_action} not supported")
 
        # TODO(shobs): Support **kwargs
        return self._apply_unary_op(
-            ops.RemoteFunctionOp(func=func, apply_on_null=(na_action is None))
+            ops.RemoteFunctionOp(
+                function_def=func.udf_def, apply_on_null=(na_action is None)
+            )
        )
 
    def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
@@ -4492,13 +4495,18 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
             )
             warnings.warn(msg, category=bfe.FunctionAxisOnePreviewWarning)
 
-            if not hasattr(func, "bigframes_bigquery_function"):
+            if not isinstance(
+                func,
+                (
+                    bigframes.functions.BigqueryCallableRoutine,
+                    bigframes.functions.BigqueryCallableRowRoutine,
+                ),
+            ):
                 raise ValueError(
                     "For axis=1 a BigFrames BigQuery function must be used."
                 )
 
-            is_row_processor = getattr(func, "is_row_processor")
-            if is_row_processor:
+            if func.is_row_processor:
                 # Early check whether the dataframe dtypes are currently supported
                 # in the bigquery function
                 # NOTE: Keep in sync with the value converters used in the gcf code
@@ -4552,7 +4560,7 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
 
                 # Apply the function
                 result_series = rows_as_json_series._apply_unary_op(
-                    ops.RemoteFunctionOp(func=func, apply_on_null=True)
+                    ops.RemoteFunctionOp(function_def=func.udf_def, apply_on_null=True)
                 )
             else:
                 # This is a special case where we are providing not-pandas-like
@@ -4567,7 +4575,7 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
                 #    compatible with the data types of the input params
                 # 3. The order of the columns in the dataframe must correspond
                 #    to the order of the input params in the function
-                udf_input_dtypes = getattr(func, "input_dtypes")
+                udf_input_dtypes = func.udf_def.signature.bf_input_types
                 if len(udf_input_dtypes) != len(self.columns):
                     raise ValueError(
                         f"BigFrames BigQuery function takes {len(udf_input_dtypes)}"
@@ -4581,25 +4589,11 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
 
                 series_list = [self[col] for col in self.columns]
                 result_series = series_list[0]._apply_nary_op(
-                    ops.NaryRemoteFunctionOp(func=func), series_list[1:]
+                    ops.NaryRemoteFunctionOp(function_def=func.udf_def), series_list[1:]
                 )
                 result_series.name = None
 
-                # If the result type is string but the function output is intended
-                # to be an array, reconstruct the array from the string assuming it
-                # is a json serialized form of the array.
-                if bigframes.dtypes.is_string_like(
-                    result_series.dtype
-                ) and bigframes.dtypes.is_array_like(func.output_dtype):
-                    import bigframes.bigquery as bbq
-
-                    result_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype(
-                        func.output_dtype.pyarrow_dtype.value_type
-                    )
-                    result_series = bbq.json_extract_string_array(
-                        result_series, value_dtype=result_dtype
-                    )
-
+                result_series = func._post_process_series(result_series)
                 return result_series
 
         # At this point column-wise or element-wise bigquery function operation will

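With these checks, `DataFrame.map` and `DataFrame.apply(axis=1)` accept only BigFrames BigQuery function objects (`BigqueryCallableRoutine` / `BigqueryCallableRowRoutine`) rather than arbitrary Python callables. A usage sketch, assuming such objects come from `read_gbq_function` and that the dataset and routine names are hypothetical:

    # Usage sketch: routine and dataset names are hypothetical.
    import bigframes.pandas as bpd

    df = bpd.DataFrame({"a": [1, 2, 3], "b": [10.0, 20.0, 30.0]})

    # Element-wise: the routine is applied to every value of the frame.
    add_one = bpd.read_gbq_function("my_project.my_dataset.add_one")
    mapped = df[["a"]].map(add_one)

    # Row-wise (axis=1): rows are serialized and passed to a row-processing routine.
    score_row = bpd.read_gbq_function(
        "my_project.my_dataset.score_row", is_row_processor=True
    )
    scores = df.apply(score_row, axis=1)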
bigframes/dtypes.py

Lines changed: 0 additions & 28 deletions
@@ -870,32 +870,4 @@ def lcd_type_or_throw(dtype1: Dtype, dtype2: Dtype) -> Dtype:
     return result
 
 
-### Remote functions use only
-# TODO: Refactor into remote function module
-
-# Input and output types supported by BigQuery DataFrames remote functions.
-# TODO(shobs): Extend the support to all types supported by BQ remote functions
-# https://cloud.google.com/bigquery/docs/remote-functions#limitations
-RF_SUPPORTED_IO_PYTHON_TYPES = {bool, bytes, float, int, str}
-
-# Support array output types in BigQuery DataFrames remote functions even though
-# it is not currently (2024-10-06) supported in BigQuery remote functions.
-# https://cloud.google.com/bigquery/docs/remote-functions#limitations
-# TODO(b/284515241): remove this special handling when BigQuery remote functions
-# support array.
-RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES = {bool, float, int, str}
-
-RF_SUPPORTED_IO_BIGQUERY_TYPEKINDS = {
-    "BOOLEAN",
-    "BOOL",
-    "BYTES",
-    "FLOAT",
-    "FLOAT64",
-    "INT64",
-    "INTEGER",
-    "STRING",
-    "ARRAY",
-}
-
-
 TIMEDELTA_DESCRIPTION_TAG = "#microseconds"

bigframes/functions/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -11,3 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from bigframes.functions.function import (
+    BigqueryCallableRoutine,
+    BigqueryCallableRowRoutine,
+)
+
+__all__ = [
+    "BigqueryCallableRoutine",
+    "BigqueryCallableRowRoutine",
+]

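The package `__init__` now re-exports the two wrapper classes, giving callers such as `dataframe.py` above a public path to type-check against:

    # Sketch: the public import path used by the isinstance checks in dataframe.py.
    import bigframes.functions


    def _is_bigquery_routine(func) -> bool:
        return isinstance(
            func,
            (
                bigframes.functions.BigqueryCallableRoutine,
                bigframes.functions.BigqueryCallableRowRoutine,
            ),
        )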