Skip to content

Commit 90ba5b4

Browse files
paleolimbot and Copilot
authored
feat(python/sedonadb): Implement Python UDFs (#228)
Co-authored-by: Copilot <[email protected]>
1 parent 958b447 commit 90ba5b4

File tree

15 files changed

+1109
-12
lines changed

15 files changed

+1109
-12
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/python.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,5 @@
2525
::: sedonadb.testing
2626

2727
::: sedonadb.dbapi
28+
29+
::: sedonadb.udf

python/sedonadb/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ futures = { workspace = true }
4242
pyo3 = { version = "0.25.1" }
4343
sedona = { path = "../../rust/sedona" }
4444
sedona-adbc = { path = "../../rust/sedona-adbc" }
45+
sedona-expr = { path = "../../rust/sedona-expr" }
4546
sedona-geoparquet = { path = "../../rust/sedona-geoparquet" }
4647
sedona-schema = { path = "../../rust/sedona-schema" }
4748
sedona-proj = { path = "../../c/sedona-proj", default-features = false }

python/sedonadb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dynamic = ["version"]
3333
test = [
3434
"adbc-driver-manager[dbapi]",
3535
"adbc-driver-postgresql",
36+
"datafusion",
3637
"duckdb",
3738
"geoarrow-pyarrow",
3839
"geopandas",

python/sedonadb/python/sedonadb/context.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,40 @@ def sql(self, sql: str) -> DataFrame:
170170
"""
171171
return DataFrame(self._impl, self._impl.sql(sql), self.options)
172172

173+
def register_udf(self, udf: Any):
    """Register a user-defined scalar function with this context.

    Args:
        udf: Either an object implementing the DataFusion PyCapsule
            protocol (i.e., exposing ``__datafusion_scalar_udf__``) or a
            function decorated with [arrow_udf][sedonadb.udf.arrow_udf].

    Examples:

        >>> import pyarrow as pa
        >>> from sedonadb import udf
        >>> sd = sedona.db.connect()
        >>> @udf.arrow_udf(pa.int64(), [udf.STRING])
        ... def char_count(arg0):
        ...     arg0 = pa.array(arg0.to_array())
        ...
        ...     return pa.array(
        ...         (len(item) for item in arg0.to_pylist()),
        ...         pa.int64()
        ...     )
        ...
        >>> sd.register_udf(char_count)
        >>> sd.sql("SELECT char_count('abcde') as col").show()
        ┌───────┐
        │  col  │
        │ int64 │
        ╞═══════╡
        │     5 │
        └───────┘

    """
    # Registration is delegated entirely to the Rust-backed implementation.
    self._impl.register_udf(udf)
206+
173207

174208
def connect() -> SedonaContext:
175209
"""Create a new [SedonaContext][sedonadb.context.SedonaContext]"""
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import inspect
19+
from typing import Any, Literal, Optional, List, Union
20+
21+
from sedonadb._lib import sedona_scalar_udf
22+
from sedonadb.utility import sedona # noqa: F401
23+
24+
25+
class TypeMatcher(str):
    """Marker for type matchers accepted as ``input_types`` of user-defined
    functions.

    The internal representation (currently a plain string) is an
    implementation detail and may change in a future release; always use the
    constants exported by the `udf` module rather than constructing values
    directly.
    """
35+
36+
37+
def arrow_udf(
    return_type: Any,
    # PEP 484: a default of None requires Optional[...] — the original
    # annotation (List[...] = None) contradicted itself.
    input_types: Optional[List[Union[TypeMatcher, Any]]] = None,
    volatility: Literal["immutable", "stable", "volatile"] = "immutable",
    name: Optional[str] = None,
):
    """Generic Arrow-based user-defined scalar function decorator

    This decorator may be used to annotate a function that accepts arguments as
    Arrow array wrappers implementing the
    [Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
    The annotated function must return a value of a consistent length of the
    appropriate type.

    !!! warning
        SedonaDB will call the provided function from multiple threads. Attempts
        to modify shared state from the body of the function may crash or cause
        unusual behaviour.

        SedonaDB Python UDFs are experimental and this interface may change based on
        user feedback.

    Args:
        return_type: One of
            - A data type (e.g., pyarrow.DataType, arro3.core.DataType, nanoarrow.Schema)
              if this function returns the same type regardless of its inputs.
            - A function of `arg_types` (list of data types) and `scalar_args` (list of
              optional scalars) that returns a data type. This function is also
              responsible for returning `None` if this function does not apply to the
              input types.
        input_types: One of
            - A list where each member is a data type or a `TypeMatcher`. The
              `udf.GEOMETRY` and `udf.GEOGRAPHY` type matchers are the most useful
              because otherwise the function will only match spatial data types whose
              coordinate reference system (CRS) also matches (i.e., based on simple
              equality). Using these type matchers will also ensure input CRS consistency
              and will automatically propagate input CRSes into the output.
            - `None`, indicating that this function can accept any number of arguments
              of any type. Usually this is paired with a functional `return_type` that
              dynamically computes a return type or returns `None` if the number or
              types of arguments do not match.
        volatility: Use "immutable" for functions whose output is always consistent
            for the same inputs (even between queries); use "stable" for functions
            whose output is always consistent for the same inputs but only within
            the same query, and use "volatile" for functions that generate random
            or otherwise non-deterministic output.
        name: An optional name for the UDF. If not given, it will be derived from
            the name of the provided function.

    Examples:

        >>> import pyarrow as pa
        >>> from sedonadb import udf
        >>> sd = sedona.db.connect()

        The simplest scalar UDF only specifies return types. This implies that
        the function can handle input of any type.

        >>> @udf.arrow_udf(pa.string())
        ... def some_udf(arg0, arg1):
        ...     arg0, arg1 = (
        ...         pa.array(arg0.to_array()).to_pylist(),
        ...         pa.array(arg1.to_array()).to_pylist(),
        ...     )
        ...     return pa.array(
        ...         (f"{item0} / {item1}" for item0, item1 in zip(arg0, arg1)),
        ...         pa.string(),
        ...     )
        ...
        >>> sd.register_udf(some_udf)
        >>> sd.sql("SELECT some_udf(123, 'abc') as col").show()
        ┌───────────┐
        │    col    │
        │   utf8    │
        ╞═══════════╡
        │ 123 / abc │
        └───────────┘

        Use the `TypeMatcher` constants where possible to specify input.
        This ensures that the function can handle the usual range of input
        types that might exist for a given input.

        >>> @udf.arrow_udf(pa.int64(), [udf.STRING])
        ... def char_count(arg0):
        ...     arg0 = pa.array(arg0.to_array())
        ...
        ...     return pa.array(
        ...         (len(item) for item in arg0.to_pylist()),
        ...         pa.int64()
        ...     )
        ...
        >>> sd.register_udf(char_count)
        >>> sd.sql("SELECT char_count('abcde') as col").show()
        ┌───────┐
        │  col  │
        │ int64 │
        ╞═══════╡
        │     5 │
        └───────┘

        In this case, the type matcher ensures we can also use the function
        for string view input which is the usual type SedonaDB emits when
        reading Parquet files.

        >>> sd.sql("SELECT char_count(arrow_cast('abcde', 'Utf8View')) as col").show()
        ┌───────┐
        │  col  │
        │ int64 │
        ╞═══════╡
        │     5 │
        └───────┘

        Geometry UDFs are best written using Shapely because pyproj (including its use
        in GeoPandas) is not thread safe and can crash when attempting to look up
        CRSes when importing an Arrow array. The UDF framework supports returning
        geometry storage to make this possible. Coordinate reference system metadata
        is propagated automatically from the input.

        >>> import shapely
        >>> import geoarrow.pyarrow as ga
        >>> @udf.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC])
        ... def shapely_udf(geom, distance):
        ...     geom_wkb = pa.array(geom.storage.to_array())
        ...     distance = pa.array(distance.to_array())
        ...     geom = shapely.from_wkb(geom_wkb)
        ...     result_shapely = shapely.buffer(geom, distance)
        ...     return pa.array(shapely.to_wkb(result_shapely))
        ...
        >>>
        >>> sd.register_udf(shapely_udf)
        >>> sd.sql("SELECT ST_SRID(shapely_udf(ST_Point(0, 0), 2.0)) as col").show()
        ┌────────┐
        │  col   │
        │ uint32 │
        ╞════════╡
        │      0 │
        └────────┘

        >>> sd.sql("SELECT ST_SRID(shapely_udf(ST_SetSRID(ST_Point(0, 0), 3857), 2.0)) as col").show()
        ┌────────┐
        │  col   │
        │ uint32 │
        ╞════════╡
        │   3857 │
        └────────┘

        Annotated functions may also declare keyword arguments `return_type` and/or `num_rows`,
        which will be passed the appropriate value by the UDF framework. This facilitates writing
        generic UDFs and/or UDFs with no arguments.

        >>> import numpy as np
        >>> def random_impl(return_type, num_rows):
        ...     pa_type = pa.field(return_type).type
        ...     return pa.array(np.random.random(num_rows), pa_type)
        ...
        >>> @udf.arrow_udf(pa.float32(), [])
        ... def random_f32(*, return_type=None, num_rows=None):
        ...     return random_impl(return_type, num_rows)
        ...
        >>> @udf.arrow_udf(pa.float64(), [])
        ... def random_f64(*, return_type=None, num_rows=None):
        ...     return random_impl(return_type, num_rows)
        ...
        >>> np.random.seed(487)
        >>> sd.register_udf(random_f32)
        >>> sd.register_udf(random_f64)
        >>> sd.sql("SELECT random_f32() AS f32, random_f64() as f64;").show()
        ┌────────────┬─────────────────────┐
        │    f32     ┆         f64         │
        │  float32   ┆       float64       │
        ╞════════════╪═════════════════════╡
        │ 0.35385555 ┆ 0.24793247139474195 │
        └────────────┴─────────────────────┘

    """

    def decorator(func):
        # Inspect once at decoration time so the per-batch call path stays
        # free of signature introspection overhead.
        kwarg_names = _callable_kwarg_only_names(func)
        if "return_type" in kwarg_names and "num_rows" in kwarg_names:

            def func_wrapper(args, return_type, num_rows):
                return func(*args, return_type=return_type, num_rows=num_rows)
        elif "return_type" in kwarg_names:

            def func_wrapper(args, return_type, num_rows):
                return func(*args, return_type=return_type)
        elif "num_rows" in kwarg_names:

            def func_wrapper(args, return_type, num_rows):
                return func(*args, num_rows=num_rows)
        else:

            def func_wrapper(args, return_type, num_rows):
                return func(*args)

        # Fall back to the wrapped function's own name when no explicit
        # name is provided (builtins/callables without __name__ yield None,
        # which ScalarUdfImpl handles).
        name_arg = func.__name__ if name is None and hasattr(func, "__name__") else name
        return ScalarUdfImpl(
            func_wrapper, return_type, input_types, volatility, name_arg
        )

    return decorator
238+
239+
240+
# Construct real TypeMatcher instances (not bare str literals) so the values
# actually satisfy their declared annotations and isinstance(x, TypeMatcher)
# checks. TypeMatcher subclasses str, so equality/hashing are unchanged.
BINARY: TypeMatcher = TypeMatcher("binary")
"""Match any binary argument (i.e., binary, binary view, large binary,
fixed-size binary)"""

BOOLEAN: TypeMatcher = TypeMatcher("boolean")
"""Match a boolean argument"""

GEOGRAPHY: TypeMatcher = TypeMatcher("geography")
"""Match a geography argument"""

GEOMETRY: TypeMatcher = TypeMatcher("geometry")
"""Match a geometry argument"""

NUMERIC: TypeMatcher = TypeMatcher("numeric")
"""Match any numeric argument"""

STRING: TypeMatcher = TypeMatcher("string")
"""Match any string argument (i.e., string, string view, large string)"""
258+
259+
260+
class ScalarUdfImpl:
    """Scalar user-defined function wrapper

    Wrapper returned by the user-defined function constructors. Instances can
    be registered with a SedonaDB context, or with any context that accepts
    DataFusion Python scalar UDFs. This object is not intended to be used to
    call a UDF directly.
    """

    def __init__(
        self,
        invoke_batch,
        return_type,
        input_types=None,
        volatility: Literal["immutable", "stable", "volatile"] = "immutable",
        name: Optional[str] = None,
    ):
        # When input_types is None the internals require a *callable*
        # return_type. The Python API also accepts a plain data type here
        # (to make the argument easier to understand), so wrap a non-callable
        # value in a constant-returning callable before storing it.
        needs_wrapping = input_types is None and not callable(return_type)
        if needs_wrapping:
            self._return_type = lambda *args, **kwargs: return_type
        else:
            self._return_type = return_type

        self._invoke_batch = invoke_batch
        self._input_types = input_types
        # Derive a name from the callable when none was given explicitly.
        if name is not None or not hasattr(invoke_batch, "__name__"):
            self._name = name
        else:
            self._name = invoke_batch.__name__

        self._volatility = volatility

    def __sedona_internal_udf__(self):
        return sedona_scalar_udf(
            self._invoke_batch,
            self._return_type,
            self._input_types,
            self._volatility,
            self._name,
        )

    def __datafusion_scalar_udf__(self):
        # Expose the DataFusion PyCapsule protocol by delegating to the
        # internal Sedona UDF object.
        return self.__sedona_internal_udf__().__datafusion_scalar_udf__()
310+
311+
312+
def _callable_kwarg_only_names(f):
313+
sig = inspect.signature(f)
314+
return [
315+
k for k, p in sig.parameters.items() if p.kind == inspect.Parameter.KEYWORD_ONLY
316+
]

0 commit comments

Comments
 (0)