Commit 5654587

Merge branch 'apache:main' into 1056/refactor/add-additional-ruff-suggestions
2 parents: 914252c + 7c1c08f

File tree: 11 files changed, +422 −50 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ dev = [
     "maturin>=1.8.1",
     "numpy>1.25.0",
     "pytest>=7.4.4",
+    "pytest-asyncio>=0.23.3",
     "ruff>=0.9.1",
     "toml>=0.10.2",
     "pygithub==2.5.0",

python/datafusion/expr.py

Lines changed: 4 additions & 4 deletions
@@ -193,7 +193,7 @@ class Expr:
     :ref:`Expressions` in the online documentation for more information.
     """
 
-    def __init__(self, expr: expr_internal.Expr) -> None:
+    def __init__(self, expr: expr_internal.RawExpr) -> None:
         """This constructor should not be called by the end user."""
         self.expr = expr
 
@@ -383,7 +383,7 @@ def literal(value: Any) -> Expr:
             value = pa.scalar(value, type=pa.string_view())
         if not isinstance(value, pa.Scalar):
             value = pa.scalar(value)
-        return Expr(expr_internal.Expr.literal(value))
+        return Expr(expr_internal.RawExpr.literal(value))
 
     @staticmethod
     def string_literal(value: str) -> Expr:
@@ -398,13 +398,13 @@ def string_literal(value: str) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string())
-            return Expr(expr_internal.Expr.literal(value))
+            return Expr(expr_internal.RawExpr.literal(value))
         return Expr.literal(value)
 
     @staticmethod
     def column(value: str) -> Expr:
         """Creates a new expression representing a column."""
-        return Expr(expr_internal.Expr.column(value))
+        return Expr(expr_internal.RawExpr.column(value))
 
     def alias(self, name: str) -> Expr:
         """Assign a name to the expression."""

python/datafusion/functions.py

Lines changed: 18 additions & 0 deletions
@@ -217,6 +217,7 @@
     "random",
     "range",
     "rank",
+    "regexp_count",
     "regexp_like",
     "regexp_match",
     "regexp_replace",
@@ -779,6 +780,23 @@ def regexp_replace(
     return Expr(f.regexp_replace(string.expr, pattern.expr, replacement.expr, flags))
 
 
+def regexp_count(
+    string: Expr, pattern: Expr, start: Expr, flags: Expr | None = None
+) -> Expr:
+    """Returns the number of matches in a string.
+
+    Optional start position (the first position is 1) to search for the regular
+    expression.
+    """
+    if flags is not None:
+        flags = flags.expr
+    if start is not None:
+        start = start.expr
+    else:
+        start = Expr.expr
+    return Expr(f.regexp_count(string.expr, pattern.expr, start, flags))
+
+
 def repeat(string: Expr, n: Expr) -> Expr:
     """Repeats the ``string`` to ``n`` times."""
     return Expr(f.repeat(string.expr, n.expr))
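A short usage sketch for the new `regexp_count` wrapper; the data and pattern mirror the test added in `python/tests/test_functions.py` further down, and the imports are the ones the test suite already uses:

```python
import pyarrow as pa
from datafusion import SessionContext, column, literal
from datafusion import functions as f

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays([pa.array(["Hello", "World", "!"])], names=["a"])
df = ctx.create_dataframe([[batch]])

# Count matches of the pattern in column "a", searching from position 1.
# The test below expects [1, 1, 0] for this input.
df = df.select(f.regexp_count(column("a"), literal("(ell|orl)"), literal(1)).alias("n"))
df.show()
```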

python/datafusion/udf.py

Lines changed: 99 additions & 24 deletions
@@ -621,31 +621,48 @@ def __call__(self, *args: Expr) -> Expr:
         args_raw = [arg.expr for arg in args]
         return Expr(self._udwf.__call__(*args_raw))
 
+    @overload
+    @staticmethod
+    def udwf(
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> Callable[..., WindowUDF]: ...
+
+    @overload
     @staticmethod
     def udwf(
         func: Callable[[], WindowEvaluator],
         input_types: pa.DataType | list[pa.DataType],
         return_type: pa.DataType,
         volatility: Volatility | str,
         name: Optional[str] = None,
-    ) -> WindowUDF:
-        """Create a new User-Defined Window Function.
+    ) -> WindowUDF: ...
 
-        If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you
-        can simply pass it's type as ``func``. If you need to pass additional arguments
-        to it's constructor, you can define a lambda or a factory method. During runtime
-        the :py:class:`WindowEvaluator` will be constructed for every instance in
-        which this UDWF is used. The following examples are all valid.
+    @staticmethod
+    def udwf(*args: Any, **kwargs: Any):  # noqa: D417
+        """Create a new User-Defined Window Function (UDWF).
 
-        .. code-block:: python
+        This class can be used both as a **function** and as a **decorator**.
+
+        Usage:
+        - **As a function**: Call `udwf(func, input_types, return_type, volatility,
+          name)`.
+        - **As a decorator**: Use `@udwf(input_types, return_type, volatility,
+          name)`. When using `udwf` as a decorator, **do not pass `func`
+          explicitly**.
 
+        **Function example:**
+        ```
         import pyarrow as pa
 
         class BiasedNumbers(WindowEvaluator):
             def __init__(self, start: int = 0) -> None:
                 self.start = start
 
-            def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
+            def evaluate_all(self, values: list[pa.Array],
+                num_rows: int) -> pa.Array:
                 return pa.array([self.start + i for i in range(num_rows)])
 
         def bias_10() -> BiasedNumbers:
@@ -655,35 +672,93 @@ def bias_10() -> BiasedNumbers:
         udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable")
         udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable")
 
+        ```
+
+        **Decorator example:**
+        ```
+        @udwf(pa.int64(), pa.int64(), "immutable")
+        def biased_numbers() -> BiasedNumbers:
+            return BiasedNumbers(10)
+        ```
+
         Args:
-            func: A callable to create the window function.
-            input_types: The data types of the arguments to ``func``.
+            func: **Only needed when calling as a function. Skip this argument when
+                using `udwf` as a decorator.**
+            input_types: The data types of the arguments.
             return_type: The data type of the return value.
             volatility: See :py:class:`Volatility` for allowed values.
-            arguments: A list of arguments to pass in to the __init__ method for accum.
             name: A descriptive name for the function.
 
         Returns:
-            A user-defined window function.
-        """  # noqa: W505, E501
+            A user-defined window function that can be used in window function calls.
+        """
+        if args and callable(args[0]):
+            # Case 1: Used as a function, require the first parameter to be callable
+            return WindowUDF._create_window_udf(*args, **kwargs)
+        # Case 2: Used as a decorator with parameters
+        return WindowUDF._create_window_udf_decorator(*args, **kwargs)
+
+    @staticmethod
+    def _create_window_udf(
+        func: Callable[[], WindowEvaluator],
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> WindowUDF:
+        """Create a WindowUDF instance from function arguments."""
         if not callable(func):
             msg = "`func` must be callable."
             raise TypeError(msg)
         if not isinstance(func(), WindowEvaluator):
             msg = "`func` must implement the abstract base class WindowEvaluator"
             raise TypeError(msg)
-        if name is None:
-            name = func().__class__.__qualname__.lower()
-        if isinstance(input_types, pa.DataType):
-            input_types = [input_types]
-        return WindowUDF(
-            name=name,
-            func=func,
-            input_types=input_types,
-            return_type=return_type,
-            volatility=volatility,
+
+        name = name or func.__qualname__.lower()
+        input_types = (
+            [input_types] if isinstance(input_types, pa.DataType) else input_types
         )
 
+        return WindowUDF(name, func, input_types, return_type, volatility)
+
+    @staticmethod
+    def _get_default_name(func: Callable) -> str:
+        """Get the default name for a function based on its attributes."""
+        if hasattr(func, "__qualname__"):
+            return func.__qualname__.lower()
+        return func.__class__.__name__.lower()
+
+    @staticmethod
+    def _normalize_input_types(
+        input_types: pa.DataType | list[pa.DataType],
+    ) -> list[pa.DataType]:
+        """Convert a single DataType to a list if needed."""
+        if isinstance(input_types, pa.DataType):
+            return [input_types]
+        return input_types
+
+    @staticmethod
+    def _create_window_udf_decorator(
+        input_types: pa.DataType | list[pa.DataType],
+        return_type: pa.DataType,
+        volatility: Volatility | str,
+        name: Optional[str] = None,
+    ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
+        """Create a decorator for a WindowUDF."""
+
+        def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:
+            udwf_caller = WindowUDF._create_window_udf(
+                func, input_types, return_type, volatility, name
+            )
+
+            @functools.wraps(func)
+            def wrapper(*args: Any, **kwargs: Any) -> Expr:
+                return udwf_caller(*args, **kwargs)
+
+            return wrapper
+
+        return decorator
+
 
 # Convenience exports so we can import instead of treating as
 # variables at the package root
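For context, a sketch of the new decorator path end to end: the overloads dispatch to `_create_window_udf_decorator`, and the `functools.wraps` wrapper returns an `Expr` when called with column expressions. It reuses the `BiasedNumbers` evaluator from the docstring; the package-root import of `udwf` is assumed here:

```python
import pyarrow as pa
from datafusion import udwf  # assumed re-export at the package root
from datafusion.udf import WindowEvaluator

class BiasedNumbers(WindowEvaluator):
    """Evaluator from the docstring example above."""

    def __init__(self, start: int = 0) -> None:
        self.start = start

    def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
        return pa.array([self.start + i for i in range(num_rows)])

# Decorator form: no `func` argument is passed. The decorated factory becomes a
# callable that, when invoked with column expressions, returns an Expr.
@udwf(pa.int64(), pa.int64(), "immutable")
def biased_numbers() -> BiasedNumbers:
    return BiasedNumbers(10)

# Hypothetical DataFrame usage: df.select(biased_numbers(col("a")).alias("biased"))
```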

python/tests/test_dataframe.py

Lines changed: 54 additions & 0 deletions
@@ -771,6 +771,16 @@ def test_execution_plan(aggregate_df):
     assert rows_returned == 5
 
 
+@pytest.mark.asyncio
+async def test_async_iteration_of_df(aggregate_df):
+    rows_returned = 0
+    async for batch in aggregate_df.execute_stream():
+        assert batch is not None
+        rows_returned += len(batch.to_pyarrow()[0])
+
+    assert rows_returned == 5
+
+
 def test_repartition(df):
     df.repartition(2)
 
@@ -958,6 +968,18 @@ def test_execute_stream(df):
     assert not list(stream)  # after one iteration the generator must be exhausted
 
 
+@pytest.mark.asyncio
+async def test_execute_stream_async(df):
+    stream = df.execute_stream()
+    batches = [batch async for batch in stream]
+
+    assert all(batch is not None for batch in batches)
+
+    # After consuming all batches, the stream should be exhausted
+    remaining_batches = [batch async for batch in stream]
+    assert not remaining_batches
+
+
 @pytest.mark.parametrize("schema", [True, False])
 def test_execute_stream_to_arrow_table(df, schema):
     stream = df.execute_stream()
@@ -974,6 +996,25 @@ def test_execute_stream_to_arrow_table(df, schema):
     assert set(pyarrow_table.column_names) == {"a", "b", "c"}
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("schema", [True, False])
+async def test_execute_stream_to_arrow_table_async(df, schema):
+    stream = df.execute_stream()
+
+    if schema:
+        pyarrow_table = pa.Table.from_batches(
+            [batch.to_pyarrow() async for batch in stream], schema=df.schema()
+        )
+    else:
+        pyarrow_table = pa.Table.from_batches(
+            [batch.to_pyarrow() async for batch in stream]
+        )
+
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (3, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+
+
 def test_execute_stream_partitioned(df):
     streams = df.execute_stream_partitioned()
     assert all(batch is not None for stream in streams for batch in stream)
@@ -982,6 +1023,19 @@ def test_execute_stream_partitioned(df):
     )  # after one iteration all generators must be exhausted
 
 
+@pytest.mark.asyncio
+async def test_execute_stream_partitioned_async(df):
+    streams = df.execute_stream_partitioned()
+
+    for stream in streams:
+        batches = [batch async for batch in stream]
+        assert all(batch is not None for batch in batches)
+
+        # Ensure the stream is exhausted after iteration
+        remaining_batches = [batch async for batch in stream]
+        assert not remaining_batches
+
+
 def test_empty_to_arrow_table(df):
     # Convert empty datafusion dataframe to pyarrow Table
     pyarrow_table = df.limit(0).to_arrow_table()
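Outside of pytest, the async iteration these tests exercise can be driven directly with `asyncio`; a minimal sketch, assuming a `SessionContext` and a trivial query (the `to_pyarrow()` and `schema()` calls are the same ones the tests above use):

```python
import asyncio

import pyarrow as pa
from datafusion import SessionContext

async def collect(df) -> pa.Table:
    # Each item yielded by the async iterator wraps a RecordBatch and exposes
    # to_pyarrow(), exactly as exercised in the tests above.
    batches = [batch.to_pyarrow() async for batch in df.execute_stream()]
    return pa.Table.from_batches(batches, schema=df.schema())

ctx = SessionContext()
df = ctx.sql("SELECT 1 AS a, 2 AS b, 3 AS c")
table = asyncio.run(collect(df))
print(table)
```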

python/tests/test_functions.py

Lines changed: 4 additions & 0 deletions
@@ -740,6 +740,10 @@ def test_array_function_obj_tests(stmt, py_expr):
             f.regexp_replace(column("a"), literal("(ell|orl)"), literal("-")),
             pa.array(["H-o", "W-d", "!"]),
         ),
+        (
+            f.regexp_count(column("a"), literal("(ell|orl)"), literal(1)),
+            pa.array([1, 1, 0], type=pa.int64()),
+        ),
     ],
 )
 def test_string_functions(df, function, expected_result):
