|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +import inspect |
| 19 | +from typing import Any, Literal, Optional, List, Union |
| 20 | + |
| 21 | +from sedonadb._lib import sedona_scalar_udf |
| 22 | +from sedonadb.utility import sedona # noqa: F401 |
| 23 | + |
| 24 | + |
| 25 | +class TypeMatcher(str): |
| 26 | + """Helper class to mark type matchers that can be used as the `input_types` for |
| 27 | + user-defined functions |
| 28 | +
|
| 29 | + Note that the internal storage of the type matcher (currently a string) is |
| 30 | + arbitrary and may change in a future release. Use the constants provided by |
| 31 | + the `udf` module. |
| 32 | + """ |
| 33 | + |
| 34 | + pass |
| 35 | + |
| 36 | + |
| 37 | +def arrow_udf( |
| 38 | + return_type: Any, |
| 39 | + input_types: List[Union[TypeMatcher, Any]] = None, |
| 40 | + volatility: Literal["immutable", "stable", "volatile"] = "immutable", |
| 41 | + name: Optional[str] = None, |
| 42 | +): |
| 43 | + """Generic Arrow-based user-defined scalar function decorator |
| 44 | +
|
| 45 | + This decorator may be used to annotate a function that accepts arguments as |
| 46 | + Arrow array wrappers implementing the |
| 47 | + [Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html). |
| 48 | + The annotated function must return a value of a consistent length of the |
| 49 | + appropriate type. |
| 50 | +
|
| 51 | + !!! warning |
| 52 | + SedonaDB will call the provided function from multiple threads. Attempts |
| 53 | + to modify shared state from the body of the function may crash or cause |
| 54 | + unusual behaviour. |
| 55 | +
|
| 56 | + SedonaDB Python UDFs are experimental and this interface may change based on |
| 57 | + user feedback. |
| 58 | +
|
| 59 | + Args: |
| 60 | + return_type: One of |
| 61 | + - A data type (e.g., pyarrow.DataType, arro3.core.DataType, nanoarrow.Schema) |
| 62 | + if this function returns the same type regardless of its inputs. |
| 63 | + - A function of `arg_types` (list of data types) and `scalar_args` (list of |
| 64 | + optional scalars) that returns a data type. This function is also |
| 65 | + responsible for returning `None` if this function does not apply to the |
| 66 | + input types. |
| 67 | + input_types: One of |
| 68 | + - A list where each member is a data type or a `TypeMatcher`. The |
| 69 | + `udf.GEOMETRY` and `udf.GEOGRAPHY` type matchers are the most useful |
| 70 | + because otherwise the function will only match spatial data types whose |
| 71 | + coordinate reference system (CRS) also matches (i.e., based on simple |
| 72 | + equality). Using these type matchers will also ensure input CRS consistency |
| 73 | + and will automatically propagate input CRSes into the output. |
| 74 | + - `None`, indicating that this function can accept any number of arguments |
| 75 | + of any type. Usually this is paired with a functional `return_type` that |
| 76 | + dynamically computes a return type or returns `None` if the number or |
| 77 | + types of arguments do not match. |
| 78 | + volatility: Use "immutable" for functions whose output is always consistent |
| 79 | + for the same inputs (even between queries); use "stable" for functions |
| 80 | + whose output is always consistent for the same inputs but only within |
| 81 | + the same query, and use "volatile" for functions that generate random |
| 82 | + or otherwise non-deterministic output. |
| 83 | + name: An optional name for the UDF. If not given, it will be derived from |
| 84 | + the name of the provided function. |
| 85 | +
|
| 86 | + Examples: |
| 87 | +
|
| 88 | + >>> import pyarrow as pa |
| 89 | + >>> from sedonadb import udf |
| 90 | + >>> sd = sedona.db.connect() |
| 91 | +
|
| 92 | + The simplest scalar UDF only specifies return types. This implies that |
| 93 | + the function can handle input of any type. |
| 94 | +
|
| 95 | + >>> @udf.arrow_udf(pa.string()) |
| 96 | + ... def some_udf(arg0, arg1): |
| 97 | + ... arg0, arg1 = ( |
| 98 | + ... pa.array(arg0.to_array()).to_pylist(), |
| 99 | + ... pa.array(arg1.to_array()).to_pylist(), |
| 100 | + ... ) |
| 101 | + ... return pa.array( |
| 102 | + ... (f"{item0} / {item1}" for item0, item1 in zip(arg0, arg1)), |
| 103 | + ... pa.string(), |
| 104 | + ... ) |
| 105 | + ... |
| 106 | + >>> sd.register_udf(some_udf) |
| 107 | + >>> sd.sql("SELECT some_udf(123, 'abc') as col").show() |
| 108 | + ┌───────────┐ |
| 109 | + │ col │ |
| 110 | + │ utf8 │ |
| 111 | + ╞═══════════╡ |
| 112 | + │ 123 / abc │ |
| 113 | + └───────────┘ |
| 114 | +
|
| 115 | + Use the `TypeMatcher` constants where possible to specify input. |
| 116 | + This ensures that the function can handle the usual range of input |
| 117 | + types that might exist for a given input. |
| 118 | +
|
| 119 | + >>> @udf.arrow_udf(pa.int64(), [udf.STRING]) |
| 120 | + ... def char_count(arg0): |
| 121 | + ... arg0 = pa.array(arg0.to_array()) |
| 122 | + ... |
| 123 | + ... return pa.array( |
| 124 | + ... (len(item) for item in arg0.to_pylist()), |
| 125 | + ... pa.int64() |
| 126 | + ... ) |
| 127 | + ... |
| 128 | + >>> sd.register_udf(char_count) |
| 129 | + >>> sd.sql("SELECT char_count('abcde') as col").show() |
| 130 | + ┌───────┐ |
| 131 | + │ col │ |
| 132 | + │ int64 │ |
| 133 | + ╞═══════╡ |
| 134 | + │ 5 │ |
| 135 | + └───────┘ |
| 136 | +
|
| 137 | + In this case, the type matcher ensures we can also use the function |
| 138 | + for string view input which is the usual type SedonaDB emits when |
| 139 | + reading Parquet files. |
| 140 | +
|
| 141 | + >>> sd.sql("SELECT char_count(arrow_cast('abcde', 'Utf8View')) as col").show() |
| 142 | + ┌───────┐ |
| 143 | + │ col │ |
| 144 | + │ int64 │ |
| 145 | + ╞═══════╡ |
| 146 | + │ 5 │ |
| 147 | + └───────┘ |
| 148 | +
|
| 149 | + Geometry UDFs are best written using Shapely because pyproj (including its use |
| 150 | + in GeoPandas) is not thread safe and can crash when attempting to look up |
| 151 | + CRSes when importing an Arrow array. The UDF framework supports returning |
| 152 | + geometry storage to make this possible. Coordinate reference system metadata |
| 153 | + is propagated automatically from the input. |
| 154 | +
|
| 155 | + >>> import shapely |
| 156 | + >>> import geoarrow.pyarrow as ga |
| 157 | + >>> @udf.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC]) |
| 158 | + ... def shapely_udf(geom, distance): |
| 159 | + ... geom_wkb = pa.array(geom.storage.to_array()) |
| 160 | + ... distance = pa.array(distance.to_array()) |
| 161 | + ... geom = shapely.from_wkb(geom_wkb) |
| 162 | + ... result_shapely = shapely.buffer(geom, distance) |
| 163 | + ... return pa.array(shapely.to_wkb(result_shapely)) |
| 164 | + ... |
| 165 | + >>> |
| 166 | + >>> sd.register_udf(shapely_udf) |
| 167 | + >>> sd.sql("SELECT ST_SRID(shapely_udf(ST_Point(0, 0), 2.0)) as col").show() |
| 168 | + ┌────────┐ |
| 169 | + │ col │ |
| 170 | + │ uint32 │ |
| 171 | + ╞════════╡ |
| 172 | + │ 0 │ |
| 173 | + └────────┘ |
| 174 | +
|
| 175 | + >>> sd.sql("SELECT ST_SRID(shapely_udf(ST_SetSRID(ST_Point(0, 0), 3857), 2.0)) as col").show() |
| 176 | + ┌────────┐ |
| 177 | + │ col │ |
| 178 | + │ uint32 │ |
| 179 | + ╞════════╡ |
| 180 | + │ 3857 │ |
| 181 | + └────────┘ |
| 182 | +
|
| 183 | + Annotated functions may also declare keyword arguments `return_type` and/or `num_rows`, |
| 184 | + which will be passed the appropriate value by the UDF framework. This facilitates writing |
| 185 | + generic UDFs and/or UDFs with no arguments. |
| 186 | +
|
| 187 | + >>> import numpy as np |
| 188 | + >>> def random_impl(return_type, num_rows): |
| 189 | + ... pa_type = pa.field(return_type).type |
| 190 | + ... return pa.array(np.random.random(num_rows), pa_type) |
| 191 | + ... |
| 192 | + >>> @udf.arrow_udf(pa.float32(), []) |
| 193 | + ... def random_f32(*, return_type=None, num_rows=None): |
| 194 | + ... return random_impl(return_type, num_rows) |
| 195 | + ... |
| 196 | + >>> @udf.arrow_udf(pa.float64(), []) |
| 197 | + ... def random_f64(*, return_type=None, num_rows=None): |
| 198 | + ... return random_impl(return_type, num_rows) |
| 199 | + ... |
| 200 | + >>> np.random.seed(487) |
| 201 | + >>> sd.register_udf(random_f32) |
| 202 | + >>> sd.register_udf(random_f64) |
| 203 | + >>> sd.sql("SELECT random_f32() AS f32, random_f64() as f64;").show() |
| 204 | + ┌────────────┬─────────────────────┐ |
| 205 | + │ f32 ┆ f64 │ |
| 206 | + │ float32 ┆ float64 │ |
| 207 | + ╞════════════╪═════════════════════╡ |
| 208 | + │ 0.35385555 ┆ 0.24793247139474195 │ |
| 209 | + └────────────┴─────────────────────┘ |
| 210 | +
|
| 211 | + """ |
| 212 | + |
| 213 | + def decorator(func): |
| 214 | + kwarg_names = _callable_kwarg_only_names(func) |
| 215 | + if "return_type" in kwarg_names and "num_rows" in kwarg_names: |
| 216 | + |
| 217 | + def func_wrapper(args, return_type, num_rows): |
| 218 | + return func(*args, return_type=return_type, num_rows=num_rows) |
| 219 | + elif "return_type" in kwarg_names: |
| 220 | + |
| 221 | + def func_wrapper(args, return_type, num_rows): |
| 222 | + return func(*args, return_type=return_type) |
| 223 | + elif "num_rows" in kwarg_names: |
| 224 | + |
| 225 | + def func_wrapper(args, return_type, num_rows): |
| 226 | + return func(*args, num_rows=num_rows) |
| 227 | + else: |
| 228 | + |
| 229 | + def func_wrapper(args, return_type, num_rows): |
| 230 | + return func(*args) |
| 231 | + |
| 232 | + name_arg = func.__name__ if name is None and hasattr(func, "__name__") else name |
| 233 | + return ScalarUdfImpl( |
| 234 | + func_wrapper, return_type, input_types, volatility, name_arg |
| 235 | + ) |
| 236 | + |
| 237 | + return decorator |
| 238 | + |
| 239 | + |
| 240 | +BINARY: TypeMatcher = "binary" |
| 241 | +"""Match any binary argument (i.e., binary, binary view, large binary, |
| 242 | +fixed-size binary)""" |
| 243 | + |
| 244 | +BOOLEAN: TypeMatcher = "boolean" |
| 245 | +"""Match a boolean argument""" |
| 246 | + |
| 247 | +GEOGRAPHY: TypeMatcher = "geography" |
| 248 | +"""Match a geography argument""" |
| 249 | + |
| 250 | +GEOMETRY: TypeMatcher = "geometry" |
| 251 | +"""Match a geometry argument""" |
| 252 | + |
| 253 | +NUMERIC: TypeMatcher = "numeric" |
| 254 | +"""Match any numeric argument""" |
| 255 | + |
| 256 | +STRING: TypeMatcher = "string" |
| 257 | +"""Match any string argument (i.e., string, string view, large string)""" |
| 258 | + |
| 259 | + |
| 260 | +class ScalarUdfImpl: |
| 261 | + """Scalar user-defined function wrapper |
| 262 | +
|
| 263 | + This class is a wrapper class used as the return value for user-defined |
| 264 | + function constructors. This wrapper allows the UDF to be registered with |
| 265 | + a SedonaDB context or any context that accepts DataFusion Python |
| 266 | + Scalar UDFs. This object is not intended to be used to call a UDF. |
| 267 | + """ |
| 268 | + |
| 269 | + def __init__( |
| 270 | + self, |
| 271 | + invoke_batch, |
| 272 | + return_type, |
| 273 | + input_types=None, |
| 274 | + volatility: Literal["immutable", "stable", "volatile"] = "immutable", |
| 275 | + name: Optional[str] = None, |
| 276 | + ): |
| 277 | + # If the input_types are None, the return_type must be callable when passed |
| 278 | + # to the internals. In the Python API we allow a data type as the return type |
| 279 | + # to the argument easier to understand, which means we may have to wrap |
| 280 | + # it in a callable here. |
| 281 | + if input_types is None and not callable(return_type): |
| 282 | + |
| 283 | + def return_type_impl(*args, **kwargs): |
| 284 | + return return_type |
| 285 | + |
| 286 | + self._return_type = return_type_impl |
| 287 | + else: |
| 288 | + self._return_type = return_type |
| 289 | + |
| 290 | + self._invoke_batch = invoke_batch |
| 291 | + self._input_types = input_types |
| 292 | + if name is None and hasattr(invoke_batch, "__name__"): |
| 293 | + self._name = invoke_batch.__name__ |
| 294 | + else: |
| 295 | + self._name = name |
| 296 | + |
| 297 | + self._volatility = volatility |
| 298 | + |
| 299 | + def __sedona_internal_udf__(self): |
| 300 | + return sedona_scalar_udf( |
| 301 | + self._invoke_batch, |
| 302 | + self._return_type, |
| 303 | + self._input_types, |
| 304 | + self._volatility, |
| 305 | + self._name, |
| 306 | + ) |
| 307 | + |
| 308 | + def __datafusion_scalar_udf__(self): |
| 309 | + return self.__sedona_internal_udf__().__datafusion_scalar_udf__() |
| 310 | + |
| 311 | + |
| 312 | +def _callable_kwarg_only_names(f): |
| 313 | + sig = inspect.signature(f) |
| 314 | + return [ |
| 315 | + k for k, p in sig.parameters.items() if p.kind == inspect.Parameter.KEYWORD_ONLY |
| 316 | + ] |
0 commit comments